# --- Residue from the web file viewer (not part of the program) ---
# File size: 6,714 Bytes
# de0db89 3518b5f de0db89 3518b5f de0db89 35bc545 3518b5f acc7f4b de0db89 77248af de0db89 1f8900f de0db89 3518b5f de0db89 3518b5f acc7f4b 3518b5f acc7f4b de0db89 4434e29 920c999 cc5a61c de0db89 77248af de0db89 77248af de0db89 3518b5f 649234d 3518b5f de0db89 920c999 cc5a61c acc7f4b 8e214b7 |
# 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 |
import jax
import jax.numpy as jnp
from flax.jax_utils import replicate
from flax.training import train_state
import optax
from diffusers import FlaxStableDiffusionPipeline
from diffusers.schedulers import FlaxPNDMScheduler
from datasets import load_dataset
from tqdm.auto import tqdm
import os
import pickle
from PIL import Image
import numpy as np
# Custom Scheduler
class CustomFlaxPNDMScheduler(FlaxPNDMScheduler):
    """PNDM scheduler whose ``add_noise`` always hands int32 timesteps to the parent."""

    def add_noise(self, state, original_samples, noise, timesteps):
        """Cast ``timesteps`` to int32 and delegate to the parent ``add_noise``.

        All other arguments are forwarded untouched and the parent's result
        is returned as-is.
        """
        return super().add_noise(
            state,
            original_samples,
            noise,
            timesteps.astype(jnp.int32),
        )
# Set up cache directories.
# cache_dir is the root for everything this script caches; model_cache_dir is
# the sub-directory that get_model() reads and writes its pickle files in.
# Creating model_cache_dir also creates cache_dir, since it is its parent.
# NOTE(review): pickles are later loaded from these fixed paths under /tmp;
# confirm that a predictable, possibly shared location is acceptable.
cache_dir = "/tmp/huggingface_cache"
model_cache_dir = os.path.join(cache_dir, "stable_diffusion_model")
os.makedirs(model_cache_dir, exist_ok=True)
print(f"Cache directory: {cache_dir}")
print(f"Model cache directory: {model_cache_dir}")
# Function to load or download the model
def get_model(model_id, revision):
    """Return ``(pipeline, params)`` for ``model_id``, backed by a pickle cache.

    The cache file lives in ``model_cache_dir`` and is named after the model
    id (with ``/`` replaced by ``_``) and the revision. If it exists, its
    contents are unpickled and returned; otherwise the pipeline is fetched
    with ``FlaxStableDiffusionPipeline.from_pretrained`` (float32) and the
    result is pickled for the next run.

    Args:
        model_id: Hub identifier of the model, e.g. ``"org/name"``.
        revision: Revision passed through to ``from_pretrained``.

    Returns:
        The ``(pipeline, params)`` pair.

    Raises:
        Whatever ``pickle`` / ``from_pretrained`` / the file system raise;
        a failed cache write is not swallowed.
    """
    model_cache_file = os.path.join(model_cache_dir, f"{model_id.replace('/', '_')}_{revision}.pkl")
    print(f"Model cache file: {model_cache_file}")

    if os.path.exists(model_cache_file):
        print("Loading model from cache...")
        # SECURITY: pickle.load executes arbitrary code embedded in the file.
        # Only safe while the cache directory is writable by trusted users.
        with open(model_cache_file, 'rb') as f:
            return pickle.load(f)

    print("Downloading model...")
    pipeline, params = FlaxStableDiffusionPipeline.from_pretrained(
        model_id,
        revision=revision,
        dtype=jnp.float32,
    )

    # Write to a private temp file and move it into place in one step, so an
    # interrupted or failed dump never leaves a truncated file at the path
    # that later runs will try to unpickle.
    partial_file = f"{model_cache_file}.{os.getpid()}.tmp"
    try:
        with open(partial_file, 'wb') as f:
            pickle.dump((pipeline, params), f)
        os.replace(partial_file, model_cache_file)
    finally:
        # After a successful replace the temp file is gone; this only
        # cleans up after a failure.
        if os.path.exists(partial_file):
            os.remove(partial_file)
    return pipeline, params
# Load the pre-trained model. get_model() returns the pipeline object together
# with its parameters; "flax" is passed as the revision.
model_id = "CompVis/stable-diffusion-v1-4"
pipeline, params = get_model(model_id, "flax")
# Use the custom scheduler: build it from the configuration of the scheduler
# the pipeline was loaded with, then install it on the pipeline.
custom_scheduler = CustomFlaxPNDMScheduler.from_config(pipeline.scheduler.config)
pipeline.scheduler = custom_scheduler
# Extract the UNet from the pipeline; it is referenced later when the training
# state is created and when the fine-tuned weights are saved.
unet = pipeline.unet
# Load and preprocess your dataset
def preprocess_images(examples):
    """Convert a batch of images into channel-first float32 arrays.

    Args:
        examples: Mapping with an ``"image"`` entry holding an iterable of
            PIL images and/or file paths (``str``).

    Returns:
        ``{"pixel_values": [...]}`` with one ``(3, 512, 512)`` float32 array
        per input image, values scaled to ``[0, 1]``.

    Raises:
        ValueError: If an entry is neither a path nor a PIL image.
    """
    def process_image(image):
        """Load (if needed), normalise to RGB 512x512 and return a CHW array."""
        if isinstance(image, str):
            # Close the underlying file deterministically; convert() returns
            # an independent in-memory copy that outlives the file handle.
            with Image.open(image) as opened:
                image = opened.convert("RGB")
        if not isinstance(image, Image.Image):
            raise ValueError(f"Unexpected image type: {type(image)}")
        image = image.convert("RGB").resize((512, 512))
        # NOTE(review): values end up in [0, 1]; Stable Diffusion VAEs are
        # usually fed [-1, 1] - confirm which range training should use.
        image = np.array(image).astype(np.float32) / 255.0
        # HWC -> CHW
        return image.transpose(2, 0, 1)

    return {"pixel_values": [process_image(img) for img in examples["image"]]}
# Load dataset from Hugging Face
# Dataset identifier on the Hugging Face Hub and the pickle file that caches
# its preprocessed form between runs.
dataset_name = "uruguayai/montevideo"
dataset_cache_file = os.path.join(cache_dir, "montevideo_dataset.pkl")
print(f"Dataset name: {dataset_name}")
print(f"Dataset cache file: {dataset_cache_file}")
if os.path.exists(dataset_cache_file):
    print("Loading dataset from cache...")
    # SECURITY: pickle.load executes arbitrary code embedded in the file.
    # Only safe while the cache directory is writable by trusted users.
    with open(dataset_cache_file, 'rb') as f:
        processed_dataset = pickle.load(f)
else:
    print("Processing dataset...")
    dataset = load_dataset(dataset_name)
    # Replace every original column of the train split with "pixel_values".
    processed_dataset = dataset["train"].map(
        preprocess_images,
        batched=True,
        remove_columns=dataset["train"].column_names,
    )
    # Dump to a private temp file and move it into place in one step, so an
    # interrupted dump never leaves a truncated file at the cache path.
    _partial_cache_file = f"{dataset_cache_file}.{os.getpid()}.tmp"
    try:
        with open(_partial_cache_file, 'wb') as f:
            pickle.dump(processed_dataset, f)
        os.replace(_partial_cache_file, dataset_cache_file)
    finally:
        # Only present here if the dump or the move failed.
        if os.path.exists(_partial_cache_file):
            os.remove(_partial_cache_file)
print(f"Processed dataset size: {len(processed_dataset)}")
# Training function
def train_step(state, batch, rng):
    """Run one gradient step of the noise-prediction (MSE) objective.

    The batch images are encoded with the pipeline's VAE, noised by the
    pipeline's scheduler at random timesteps, and ``state.apply_fn`` is asked
    to predict that noise, conditioned on random stand-in text embeddings.
    Uses the module-level ``pipeline`` for the VAE, scheduler and text
    encoder configuration.

    Args:
        state: Training state exposing ``params`` (with ``"vae"`` and
            ``"unet"`` subtrees), ``apply_fn`` and ``apply_gradients``.
        batch: Mapping with a ``"pixel_values"`` entry holding the images.
        rng: PRNG key for this step.

    Returns:
        ``(state, loss)``: the updated state and the scalar loss evaluated at
        the parameters *before* the update.
    """
    # One independent key per random draw. Reusing a single key would make
    # the target noise identical to the noise used to sample the VAE latents
    # (same key, same shape) and correlate all other draws with it.
    sample_rng, noise_rng, timestep_rng, text_rng = jax.random.split(rng, 4)

    pixel_values = jnp.asarray(batch["pixel_values"], dtype=jnp.float32)

    def compute_loss(params):
        # Encode images to latent space.
        latents = pipeline.vae.apply(
            {"params": params["vae"]},
            pixel_values,
            method=pipeline.vae.encode,
        ).latent_dist.sample(sample_rng)
        # The Flax VAE returns channel-last latents while the Flax UNet takes
        # channel-first input (as in the diffusers Flax training example).
        latents = jnp.transpose(latents, (0, 3, 1, 2))
        latents = latents * jnp.float32(0.18215)

        # Noise target and per-sample diffusion timestep.
        noise = jax.random.normal(noise_rng, latents.shape, dtype=jnp.float32)
        timesteps = jax.random.randint(
            timestep_rng,
            (latents.shape[0],),
            0,
            pipeline.scheduler.config.num_train_timesteps,
        )

        noisy_latents = pipeline.scheduler.add_noise(
            pipeline.scheduler.create_state(),
            original_samples=latents,
            noise=noise,
            timesteps=timesteps,
        )

        # Random encoder hidden states (simulating text embeddings). The
        # UNet's cross-attention expects (batch, sequence, hidden).
        text_config = pipeline.text_encoder.config
        encoder_hidden_states = jax.random.normal(
            text_rng,
            (
                latents.shape[0],
                text_config.max_position_embeddings,
                text_config.hidden_size,
            ),
            dtype=jnp.float32,
        )

        model_output = state.apply_fn(
            {'params': params["unet"]},
            noisy_latents,
            timesteps,
            encoder_hidden_states=encoder_hidden_states,
            train=True,
        )
        # The UNet wraps its prediction in an output object; unwrap it, but
        # also accept an apply_fn that already returns a plain array.
        predicted_noise = getattr(model_output, "sample", model_output)

        return jnp.mean((predicted_noise - noise) ** 2)

    # Single forward/backward pass yields both the loss and its gradients.
    # allow_int is kept because the parameter tree may contain int leaves.
    loss, grads = jax.value_and_grad(compute_loss, allow_int=True)(state.params)
    state = state.apply_gradients(grads=grads)
    return state, loss
# Initialize training state
learning_rate = 1e-5
optimizer = optax.adam(learning_rate)
# Cast every non-int32 leaf to float32; int32 leaves are kept as they are.
# jax.tree_util.tree_map is the stable spelling (the top-level jax.tree_map
# alias is deprecated/removed in recent JAX releases).
float32_params = jax.tree_util.tree_map(
    lambda x: x.astype(jnp.float32) if x.dtype != jnp.int32 else x,
    params,
)
# apply_fn must accept a variables dict ({'params': ...}) as first argument,
# which is how train_step calls it; that is Module.apply, not __call__ on the
# unbound module.
state = train_state.TrainState.create(
    apply_fn=unet.apply,
    params=float32_params,
    tx=optimizer,
)
# Training loop
num_epochs = 3
batch_size = 1
rng = jax.random.PRNGKey(0)
for epoch in range(num_epochs):
    epoch_loss = 0
    num_batches = 0
    for batch in tqdm(processed_dataset.batch(batch_size)):
        batch['pixel_values'] = jnp.array(batch['pixel_values'], dtype=jnp.float32)
        # Fresh key per step; `rng` itself is carried forward.
        rng, step_rng = jax.random.split(rng)
        state, loss = train_step(state, batch, step_rng)
        epoch_loss += loss
        num_batches += 1
        # Periodically drop JAX's compilation caches.
        if num_batches % 10 == 0:
            jax.clear_caches()
    # An empty dataset yields no batches; report NaN instead of dividing by 0.
    avg_loss = epoch_loss / num_batches if num_batches else float("nan")
    print(f"Epoch {epoch+1}/{num_epochs}, Average Loss: {avg_loss}")
    jax.clear_caches()
# Save the fine-tuned model: only the "unet" subtree of the trained parameters
# is written, into a directory that is created on demand.
output_dir = "/tmp/montevideo_fine_tuned_model"
os.makedirs(output_dir, exist_ok=True)
unet.save_pretrained(output_dir, params=state.params["unet"])
print(f"Model saved to {output_dir}")