# trainflux / app.py
import jax
import jax.numpy as jnp
from flax.training import train_state
import optax
from diffusers import FlaxStableDiffusionPipeline
from datasets import load_dataset
from tqdm.auto import tqdm
import os
import pickle
from PIL import Image
import numpy as np
import gc
# Force JAX to use CPU
jax.config.update('jax_platform_name', 'cpu')
print("Using CPU for computations")
# Set up cache directories
cache_dir = "/tmp/huggingface_cache"
model_cache_dir = os.path.join(cache_dir, "stable_diffusion_model")
os.makedirs(model_cache_dir, exist_ok=True)
print(f"Cache directory: {cache_dir}")
print(f"Model cache directory: {model_cache_dir}")
# Function to load or download the model
def get_model(model_id, revision):
    model_cache_file = os.path.join(model_cache_dir, f"{model_id.replace('/', '_')}_{revision}.pkl")
    print(f"Model cache file: {model_cache_file}")
    if os.path.exists(model_cache_file):
        print("Loading model from cache...")
        with open(model_cache_file, 'rb') as f:
            return pickle.load(f)
    else:
        print("Downloading model...")
        pipeline, params = FlaxStableDiffusionPipeline.from_pretrained(
            model_id,
            revision=revision,
            dtype=jnp.float32,  # Use float32 for CPU
        )
        with open(model_cache_file, 'wb') as f:
            pickle.dump((pipeline, params), f)
        return pipeline, params
# Load the pre-trained model
model_id = "CompVis/stable-diffusion-v1-4"
pipeline, params = get_model(model_id, "flax")
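# params is a dict of Flax weights keyed per pipeline component (e.g. "unet", "vae", "text_encoder")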
# Extract UNet from pipeline
unet = pipeline.unet
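# The UNet is the denoising network; it is the component that gets saved after fine-tuning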
# Load and preprocess your dataset
def preprocess_images(examples):
    def process_image(image):
        if isinstance(image, str):
            image = Image.open(image)
        if not isinstance(image, Image.Image):
            raise ValueError(f"Unexpected image type: {type(image)}")
        image = image.convert("RGB").resize((512, 512))
        # Scale to [-1, 1], the range the Stable Diffusion VAE expects
        image = np.array(image).astype(np.float32) / 127.5 - 1.0
        return image.transpose(2, 0, 1)  # HWC -> CHW
    return {"pixel_values": [process_image(img) for img in examples["image"]]}
# Load dataset from Hugging Face
dataset_name = "uruguayai/montevideo"
dataset_cache_file = os.path.join(cache_dir, "montevideo_dataset.pkl")
print(f"Dataset name: {dataset_name}")
print(f"Dataset cache file: {dataset_cache_file}")
try:
    if os.path.exists(dataset_cache_file):
        print("Loading dataset from cache...")
        with open(dataset_cache_file, 'rb') as f:
            processed_dataset = pickle.load(f)
    else:
        print("Loading dataset from Hugging Face...")
        dataset = load_dataset(dataset_name, split="train[:500]")  # Load only the first 500 samples
        print("Processing dataset...")
        processed_dataset = dataset.map(preprocess_images, batched=True, remove_columns=dataset.column_names)
        with open(dataset_cache_file, 'wb') as f:
            pickle.dump(processed_dataset, f)
    print(f"Processed dataset size: {len(processed_dataset)}")
except Exception as e:
    print(f"Error loading or processing dataset: {str(e)}")
    raise ValueError("Unable to load or process the dataset.") from e
# Function to clear the JIT cache (helps keep memory bounded during long CPU runs)
def clear_jit_cache():
    jax.clear_caches()
    gc.collect()
# Training function
def train_step(state, batch, rng):
    def compute_loss(params, pixel_values, rng):
        # Ensure pixel_values are float32
        pixel_values = jnp.array(pixel_values, dtype=jnp.float32)
        # Use independent RNG streams for latent sampling, noise, timesteps, and hidden states
        sample_rng, noise_rng, timestep_rng, hidden_rng = jax.random.split(rng, 4)
        # Encode images to latent space
        latents = pipeline.vae.apply(
            {"params": params["vae"]},
            pixel_values,
            method=pipeline.vae.encode
        ).latent_dist.sample(sample_rng)
        # 0.18215 is the latent scaling factor used by Stable Diffusion's VAE
        latents = latents * jnp.float32(0.18215)
        # Generate random noise
        noise = jax.random.normal(noise_rng, latents.shape, dtype=jnp.float32)
        # Sample random timesteps and keep them as int32 for the scheduler
        timesteps = jax.random.randint(
            timestep_rng, (latents.shape[0],), 0, pipeline.scheduler.config.num_train_timesteps
        )
        timesteps = timesteps.astype(jnp.int32)
        # Add noise to latents
        noisy_latents = pipeline.scheduler.add_noise(
            pipeline.scheduler.create_state(),
            original_samples=latents,
            noise=noise,
            timesteps=timesteps
        )
        # Generate random encoder hidden states (simulating text embeddings);
        # the UNet's cross-attention expects a (batch, sequence, hidden) tensor
        encoder_hidden_states = jax.random.normal(
            hidden_rng,
            (latents.shape[0], 1, pipeline.text_encoder.config.hidden_size),
            dtype=jnp.float32
        )
        # Predict noise
        model_output = state.apply_fn(
            {'params': params["unet"]},
            noisy_latents,
            timesteps,
            encoder_hidden_states=encoder_hidden_states,
            train=True,
        ).sample
        # Compute mean squared error between predicted and true noise
        return jnp.mean((model_output - noise) ** 2)

    # Compute loss and gradients in a single forward/backward pass
    grad_fn = jax.value_and_grad(compute_loss, argnums=0, allow_int=True)
    rng, step_rng = jax.random.split(rng)
    loss, grads = grad_fn(state.params, batch["pixel_values"], step_rng)
    state = state.apply_gradients(grads=grads)
    return state, loss
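# Note: train_step is not jitted here; wrapping it in jax.jit would compile the whole
# update step and is worth trying if iteration speed becomes a problem.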
# Initialize training state
learning_rate = 1e-5
optimizer = optax.adam(learning_rate)
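# TrainState bundles the parameters, the optimizer state, and the apply function used in train_step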
state = train_state.TrainState.create(
    apply_fn=unet.apply,  # Use the Flax module's apply function
    params=params,        # Pass all pipeline params
    tx=optimizer,
)
# Training loop
num_epochs = 3
batch_size = 1
rng = jax.random.PRNGKey(0)
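# A fixed seed keeps the run reproducible; the key is split once per training step below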
for epoch in range(num_epochs):
    epoch_loss = 0
    num_batches = 0
    for batch in tqdm(processed_dataset.batch(batch_size)):
        batch['pixel_values'] = jnp.array(batch['pixel_values'], dtype=jnp.float32)
        rng, step_rng = jax.random.split(rng)
        state, loss = train_step(state, batch, step_rng)
        epoch_loss += loss
        num_batches += 1
        if num_batches % 10 == 0:
            clear_jit_cache()
    avg_loss = epoch_loss / num_batches
    print(f"Epoch {epoch+1}/{num_epochs}, Average Loss: {avg_loss}")
    clear_jit_cache()
# Save the fine-tuned model
output_dir = "/tmp/montevideo_fine_tuned_model"
os.makedirs(output_dir, exist_ok=True)
unet.save_pretrained(output_dir, params=state.params["unet"])
print(f"Model saved to {output_dir}")