# trainflux / app.py
# Author: uruguayai — last commit 399bb13 ("Update app.py", verified)
import jax
import jax.numpy as jnp
from flax.jax_utils import replicate
from flax.training import train_state
import optax
from diffusers import FlaxStableDiffusionPipeline, FlaxUNet2DConditionModel
from diffusers.schedulers import FlaxPNDMScheduler
from datasets import load_dataset
from tqdm.auto import tqdm
import os
import pickle
from PIL import Image
import numpy as np
from inspect import signature
# Custom Scheduler
class CustomFlaxPNDMScheduler(FlaxPNDMScheduler):
    """``FlaxPNDMScheduler`` whose ``add_noise`` casts ``timesteps`` to int32 first.

    All other behaviour is inherited unchanged from the parent scheduler.
    """

    def add_noise(self, state, original_samples, noise, timesteps):
        """Cast ``timesteps`` to ``jnp.int32`` and delegate to the parent ``add_noise``."""
        int32_timesteps = timesteps.astype(jnp.int32)
        return super().add_noise(state, original_samples, noise, int32_timesteps)
# Set up cache directories: model pickles live in a subfolder of the shared cache.
cache_dir = "/tmp/huggingface_cache"
model_cache_dir = os.path.join(cache_dir, "stable_diffusion_model")
os.makedirs(model_cache_dir, exist_ok=True)  # creates cache_dir as well
print(
    f"Cache directory: {cache_dir}\n"
    f"Model cache directory: {model_cache_dir}"
)
def filter_dict(dict_to_filter, target_callable):
    """Return the entries of ``dict_to_filter`` whose keys name parameters of ``target_callable``.

    Useful for passing a config dict to a constructor that would reject
    unexpected keyword arguments. Input order is preserved; the input dict is
    not modified.
    """
    accepted_names = set(signature(target_callable).parameters)
    kept = {}
    for name, value in dict_to_filter.items():
        if name in accepted_names:
            kept[name] = value
    return kept
# Function to load or download the model
def get_model(model_id, revision):
    """Return ``(pipeline, params)`` for ``model_id`` at ``revision``, via a local pickle cache.

    Cache hit: the tuple is unpickled from ``model_cache_dir``.
    Cache miss: the pipeline is downloaded with
    ``FlaxStableDiffusionPipeline.from_pretrained(model_id, revision=revision,
    dtype=jnp.float32)`` and pickled for the next run.

    A cache file that cannot be unpickled is ignored and the model is
    re-downloaded. This covers a truncated file or one referring to classes
    that no longer import.

    Args:
        model_id: Hugging Face repo id, e.g. ``"CompVis/stable-diffusion-v1-4"``.
        revision: repo revision to download (also part of the cache file name).

    Returns:
        The ``(pipeline, params)`` tuple produced by ``from_pretrained``.
    """
    model_cache_file = os.path.join(model_cache_dir, f"{model_id.replace('/', '_')}_{revision}.pkl")
    print(f"Model cache file: {model_cache_file}")
    if os.path.exists(model_cache_file):
        print("Loading model from cache...")
        # SECURITY NOTE(review): pickle.load can execute arbitrary code, and the
        # cache lives under world-writable /tmp. Only rely on this cache on a
        # machine where no untrusted user can write to that directory.
        try:
            with open(model_cache_file, 'rb') as f:
                return pickle.load(f)
        except (pickle.UnpicklingError, EOFError, AttributeError, ImportError) as err:
            print(f"Model cache is unreadable ({err!r}); re-downloading...")
    print("Downloading model...")
    pipeline, params = FlaxStableDiffusionPipeline.from_pretrained(
        model_id,
        revision=revision,
        dtype=jnp.float32,
    )
    # Dump to a temporary file, then rename it into place. An interrupted dump
    # can then never leave a truncated cache that later runs would try to load.
    tmp_cache_file = model_cache_file + ".tmp"
    with open(tmp_cache_file, 'wb') as f:
        pickle.dump((pipeline, params), f)
    os.replace(tmp_cache_file, model_cache_file)
    return pipeline, params
# Load the pre-trained Stable Diffusion v1-4 pipeline at revision "flax"
# (served from the pickle cache when one exists).
model_id = "CompVis/stable-diffusion-v1-4"
pipeline, params = get_model(model_id, "flax")

# Swap in the int32-casting scheduler subclass, built from the same config.
custom_scheduler = CustomFlaxPNDMScheduler.from_config(
    pipeline.scheduler.config
)
pipeline.scheduler = custom_scheduler

# Keep a direct handle on the UNet module and log its configuration.
unet = pipeline.unet
print("UNet configuration:", unet.config, sep="\n")
# Adjust the input layer of the UNet
def adjust_unet_input_layer(params, in_channels=4):
    """Resize the UNet ``conv_in`` kernel so it takes ``in_channels`` input channels.

    ``params`` may be any of:
      * full pipeline params with the UNet tree under ``"unet"``;
      * a Flax variables dict with the UNet tree under ``"params"``
        (the output of ``Module.init``);
      * the bare UNet parameter tree.

    The kernel is indexed as ``(kh, kw, in, out)``, so ``shape[2]`` is the
    input-channel axis. When it differs from ``in_channels``, a zero kernel of
    shape ``(kh, kw, in_channels, out)`` is built with the kernel's dtype. The
    first ``min(in, in_channels)`` input channels are then copied across, and
    any extra channels start at zero. Other ``conv_in`` entries (e.g. the bias)
    are kept as-is.

    Args:
        params: parameter tree as described above. It is never mutated.
        in_channels: desired number of input channels (default 4).

    Returns:
        A tree with the resized kernel, built from new dicts along the modified
        path. Returns ``params`` itself when ``conv_in`` is missing or already
        has ``in_channels`` inputs.
    """
    if "unet" in params:
        wrapper_key = "unet"
    elif "params" in params and "conv_in" in params["params"]:
        wrapper_key = "params"
    else:
        wrapper_key = None
    unet_params = params[wrapper_key] if wrapper_key is not None else params

    if "conv_in" not in unet_params:
        print("Warning: 'conv_in' not found in UNet params. Skipping input layer adjustment.")
        return params

    conv_in_weight = unet_params["conv_in"]["kernel"]
    print(f"Original conv_in weight shape: {conv_in_weight.shape}")
    if conv_in_weight.shape[2] == in_channels:
        return params

    kernel_h, kernel_w, old_in_channels, out_channels = conv_in_weight.shape
    # Derive every dimension from the real kernel. The old hard-coded
    # (3, 3, 4, 320) target and 3-channel copy crashed for other shapes.
    n_copy = min(old_in_channels, in_channels)
    new_conv_in_weight = jnp.zeros(
        (kernel_h, kernel_w, in_channels, out_channels), dtype=conv_in_weight.dtype
    )
    new_conv_in_weight = new_conv_in_weight.at[:, :, :n_copy, :].set(
        conv_in_weight[:, :, :n_copy, :]
    )
    print(f"New conv_in weight shape: {new_conv_in_weight.shape}")

    # Rebuild the dicts on the modified path instead of assigning in place. This
    # leaves the caller's tree untouched and also works for read-only
    # mappings such as flax FrozenDict.
    new_conv_in = {**unet_params["conv_in"], "kernel": new_conv_in_weight}
    new_unet_params = {**unet_params, "conv_in": new_conv_in}
    if wrapper_key is None:
        return new_unet_params
    return {**params, wrapper_key: new_unet_params}
# Run the conv_in adjustment on the loaded pipeline parameters.
params = adjust_unet_input_layer(params=params)
# Load and preprocess your dataset
def preprocess_images(examples):
    """Convert a batch of dataset rows into normalised, channel-first image arrays.

    Intended for ``Dataset.map(..., batched=True)``. Each item of
    ``examples["image"]`` is either a ``PIL.Image.Image`` or a file path.
    Paths not ending in ``.jpg``/``.jpeg`` (case-insensitive) and items that
    are not images are dropped. The output can therefore have fewer rows
    than the input.

    Returns:
        ``{"pixel_values": [...]}``. Each entry is a float32 array of shape
        (3, 512, 512) holding RGB values scaled to [-1, 1].

    NOTE(review): the [-1, 1] scaling replaces the earlier [0, 1] scaling.
    Delete any previously pickled dataset cache so it is rebuilt.
    """

    def process_image(image):
        if isinstance(image, str):
            if not image.lower().endswith((".jpg", ".jpeg")):
                return None
            # Decode inside a context manager so the file handle is closed;
            # convert() loads the pixels into a new in-memory image.
            with Image.open(image) as opened:
                image = opened.convert("RGB")
        elif not isinstance(image, Image.Image):
            return None
        image = image.convert("RGB").resize((512, 512))
        pixels = np.asarray(image, dtype=np.float32) / 255.0
        # The Stable Diffusion VAE is trained on pixels in [-1, 1]
        # (diffusers' training scripts use Normalize([0.5], [0.5])).
        pixels = pixels * 2.0 - 1.0
        return pixels.transpose(2, 0, 1)  # HWC -> CHW

    processed = [process_image(img) for img in examples["image"]]
    return {"pixel_values": [img for img in processed if img is not None]}
# Load dataset from Hugging Face, or from the local pickle cache if present.
dataset_name = "uruguayai/montevideo"
dataset_cache_file = os.path.join(cache_dir, "montevideo_dataset.pkl")
print(f"Dataset name: {dataset_name}")
print(f"Dataset cache file: {dataset_cache_file}")
if os.path.exists(dataset_cache_file):
    print("Loading dataset from cache...")
    # SECURITY NOTE(review): pickle.load can execute arbitrary code, and this
    # file lives under world-writable /tmp. Trust it only on a single-user
    # machine.
    with open(dataset_cache_file, 'rb') as f:
        processed_dataset = pickle.load(f)
else:
    print("Processing dataset...")
    dataset = load_dataset(dataset_name)
    train_split = dataset["train"]
    processed_dataset = train_split.map(
        preprocess_images,
        batched=True,
        remove_columns=train_split.column_names,
    )
    processed_dataset = processed_dataset.filter(lambda example: len(example['pixel_values']) > 0)
    # Write to a temporary file and rename it into place. An interrupted dump
    # then cannot leave a truncated cache that later runs would fail to load.
    tmp_dataset_cache_file = dataset_cache_file + ".tmp"
    with open(tmp_dataset_cache_file, 'wb') as f:
        pickle.dump(processed_dataset, f)
    os.replace(tmp_dataset_cache_file, dataset_cache_file)
print(f"Processed dataset size: {len(processed_dataset)}")
# Inspect the first single-example batch so the input layout is visible in the logs.
sample_batch = next(iter(processed_dataset.batch(1)))
sample_pixel_values = sample_batch["pixel_values"]
print(f"Sample batch keys: {sample_batch.keys()}")
print(f"Sample pixel_values type: {type(sample_pixel_values)}")
print(f"Sample pixel_values length: {len(sample_pixel_values)}")
if len(sample_pixel_values) > 0:
    first_shape = np.array(sample_pixel_values[0]).shape
    print(f"Sample pixel_values[0] shape: {first_shape}")
# Training function
def train_step(state, batch, rng):
    """Run one noise-prediction optimisation step on the UNet held in ``state``.

    Reads the module-level ``pipeline`` (VAE, scheduler, text-encoder config)
    and ``params["vae"]`` (VAE weights). Only the UNet parameters in
    ``state.params`` are differentiated and updated.

    Args:
        state: ``flax.training.train_state.TrainState`` for the UNet. Its
            ``params`` may be wrapped in one or more ``{"params": ...}`` levels.
        batch: mapping whose ``"pixel_values"`` holds one channel-first image
            (C, H, W) or a batch of them (B, C, H, W).
        rng: ``jax.random.PRNGKey`` for this step.

    Returns:
        ``(new_state, loss)``. ``new_state`` has the optimizer update applied.
        ``loss`` is the scalar MSE between predicted and true noise.
    """

    def unet_variables(param_tree):
        # Peel off any {"params": ...} wrappers so apply_fn always receives
        # {"params": <UNet module tree>}, whatever nesting the TrainState uses.
        while "conv_in" not in param_tree and "params" in param_tree:
            param_tree = param_tree["params"]
        return {"params": param_tree}

    def compute_loss(param_tree, pixel_values, rng):
        pixel_values = jnp.asarray(pixel_values, dtype=jnp.float32)
        if pixel_values.ndim == 3:
            pixel_values = jnp.expand_dims(pixel_values, axis=0)  # add batch dimension
        print(f"pixel_values shape in compute_loss: {pixel_values.shape}")
        # Independent keys for each random draw. Reusing one key correlates
        # the latent sample, noise, timesteps and conditioning.
        latent_rng, noise_rng, timestep_rng, hidden_rng = jax.random.split(rng, 4)

        vae_output = pipeline.vae.apply(
            {"params": params["vae"]},
            pixel_values,
            deterministic=True,
            method=pipeline.vae.encode,
        )
        latents = vae_output.latent_dist.sample(latent_rng)
        # The Flax VAE returns NHWC latents; the Flax UNet takes NCHW input.
        latents = jnp.transpose(latents, (0, 3, 1, 2))
        latents = latents * jnp.float32(0.18215)
        print(f"latents shape: {latents.shape}")

        noise = jax.random.normal(noise_rng, latents.shape, dtype=jnp.float32)
        timesteps = jax.random.randint(
            timestep_rng, (latents.shape[0],), 0, pipeline.scheduler.config.num_train_timesteps
        )
        noisy_latents = pipeline.scheduler.add_noise(
            pipeline.scheduler.create_state(),
            original_samples=latents,
            noise=noise,
            timesteps=timesteps,
        )
        # Cross-attention conditioning must be (batch, seq_len, hidden_size);
        # the previous 2-D (batch, hidden_size) tensor did not fit the UNet.
        # Random embeddings are kept as the (unconditional) conditioning signal.
        text_config = pipeline.text_encoder.config
        encoder_hidden_states = jax.random.normal(
            hidden_rng,
            (latents.shape[0], text_config.max_position_embeddings, text_config.hidden_size),
            dtype=jnp.float32,
        )
        print(f"noisy_latents shape: {noisy_latents.shape}")
        print(f"timesteps shape: {timesteps.shape}")
        print(f"encoder_hidden_states shape: {encoder_hidden_states.shape}")

        model_output = state.apply_fn(
            unet_variables(param_tree),
            noisy_latents,
            jnp.asarray(timesteps, dtype=jnp.int32),
            encoder_hidden_states,
            train=True,
        ).sample
        return jnp.mean((model_output - noise) ** 2)

    rng, step_rng = jax.random.split(rng)
    # Differentiate w.r.t. the full state.params tree so the gradients share
    # its structure. One forward pass yields both the loss and the gradients.
    loss, grads = jax.value_and_grad(compute_loss)(state.params, batch["pixel_values"], step_rng)
    # Route the gradients through the optimizer (state.tx). Adding raw
    # gradients to the params would step uphill with an effective lr of 1.
    state = state.apply_gradients(grads=grads)
    return state, loss
# Initialize training state
learning_rate = 1e-5
optimizer = optax.adam(learning_rate)

# Fine-tune from the pretrained UNet weights in the pipeline params. Earlier
# code trained the tree returned by adjusted_unet.init(PRNGKey(0), ...): a
# freshly random UNet, so the pretrained weights were discarded.
pretrained_unet_params = params["unet"] if "unet" in params else params
# Cast every non-int32 leaf to float32 (only the UNet subtree is copied).
adjusted_params = jax.tree_util.tree_map(
    lambda x: x.astype(jnp.float32) if x.dtype != jnp.int32 else x,
    pretrained_unet_params,
)

# Rebuild the UNet module from the pipeline's UNet config, keeping only the keys
# its constructor accepts, with in_channels matching the actual conv_in kernel.
unet_config = dict(unet.config)
filtered_unet_config = filter_dict(unet_config, FlaxUNet2DConditionModel.__init__)
if "in_channels" in filtered_unet_config and "conv_in" in adjusted_params:
    filtered_unet_config["in_channels"] = int(adjusted_params["conv_in"]["kernel"].shape[2])
print("Filtered UNet config keys:", filtered_unet_config.keys())
adjusted_unet = FlaxUNet2DConditionModel(**filtered_unet_config)

state = train_state.TrainState.create(
    apply_fn=adjusted_unet.apply,
    # Wrapped once so that state.params["params"] is the UNet's own tree.
    params={"params": adjusted_params},
    tx=optimizer,
)
# Training loop
num_epochs = 3
batch_size = 1
rng = jax.random.PRNGKey(0)
for epoch in range(num_epochs):
    epoch_loss = 0.0
    num_batches = 0
    num_errors = 0
    for batch in tqdm(processed_dataset.batch(batch_size)):
        try:
            # Keep the whole batch as (B, C, H, W). Indexing [0] here dropped
            # every image except the first whenever batch_size > 1.
            batch['pixel_values'] = jnp.asarray(batch['pixel_values'], dtype=jnp.float32)
            rng, step_rng = jax.random.split(rng)
            state, loss = train_step(state, batch, step_rng)
        except Exception as e:
            # Best effort per batch: count and report the failure, then move on.
            num_errors += 1
            print(f"Error processing batch: {e}")
            continue
        loss = float(loss)  # host scalar; avoids accumulating device arrays
        epoch_loss += loss
        num_batches += 1
        if num_batches % 10 == 0:
            jax.clear_caches()
            print(f"Processed {num_batches} batches. Current loss: {loss}")
    if num_batches > 0:
        avg_loss = epoch_loss / num_batches
        print(f"Epoch {epoch+1}/{num_epochs}, Average Loss: {avg_loss}, Errors: {num_errors}")
    else:
        print(f"Epoch {epoch+1}/{num_epochs}, No valid batches processed, Errors: {num_errors}")
        if num_errors > 0:
            # Every batch failed. Stop here rather than go on to save an
            # untrained model as if fine-tuning had succeeded.
            raise RuntimeError(
                f"All {num_errors} batches failed in epoch {epoch+1}; aborting training."
            )
    jax.clear_caches()
# Save the fine-tuned model
output_dir = "/tmp/montevideo_fine_tuned_model"
os.makedirs(output_dir, exist_ok=True)
# Strip any {"params": ...} wrapper levels so the saved tree is the UNet's own
# parameter dict (top-level keys such as "conv_in"). That matches the layout of
# the pretrained UNet params; a wrapped tree would be saved one level too deep.
unet_params_to_save = state.params
while "conv_in" not in unet_params_to_save and "params" in unet_params_to_save:
    unet_params_to_save = unet_params_to_save["params"]
adjusted_unet.save_pretrained(output_dir, params=unet_params_to_save)
print(f"Model saved to {output_dir}")