# trainflux / app.py (by uruguayai; last commit 649234d "Update app.py", 6 kB)
import jax
import jax.numpy as jnp
from flax.jax_utils import replicate
from flax.training import train_state
import optax
from diffusers import FlaxStableDiffusionPipeline
from datasets import load_dataset
from tqdm.auto import tqdm
import os
import pickle
from PIL import Image
import numpy as np
# Cache locations used by this script. Both live under /tmp; makedirs with
# exist_ok creates the parent cache_dir as well and is safe to re-run.
cache_dir = "/tmp/huggingface_cache"
model_cache_dir = os.path.join(cache_dir, "stable_diffusion_model")
os.makedirs(model_cache_dir, exist_ok=True)
print("Cache directory:", cache_dir)
print("Model cache directory:", model_cache_dir)
# Function to load or download the model
def get_model(model_id, revision):
    """Return ``(pipeline, params)`` for a Flax Stable Diffusion checkpoint.

    The pair is pickled under ``model_cache_dir`` after the first download
    and reloaded from that pickle on later calls. An unreadable cache file
    (e.g. truncated by an interrupted earlier run) is reported and ignored,
    and the model is downloaded again instead of crashing.

    Args:
        model_id: Hugging Face Hub repo id, e.g. "CompVis/stable-diffusion-v1-4".
        revision: Hub revision passed to ``from_pretrained``.

    Returns:
        The ``(pipeline, params)`` tuple produced by
        ``FlaxStableDiffusionPipeline.from_pretrained``.
    """
    model_cache_file = os.path.join(
        model_cache_dir, f"{model_id.replace('/', '_')}_{revision}.pkl"
    )
    print(f"Model cache file: {model_cache_file}")
    if os.path.exists(model_cache_file):
        print("Loading model from cache...")
        # SECURITY NOTE: pickle.load executes code embedded in the file. This
        # cache lives under /tmp; only trust it if nothing untrusted can write there.
        try:
            with open(model_cache_file, "rb") as f:
                return pickle.load(f)
        except (OSError, EOFError, pickle.UnpicklingError, AttributeError, ImportError) as err:
            print(f"Ignoring unreadable model cache ({err}); downloading again.")

    print("Downloading model...")
    pipeline, params = FlaxStableDiffusionPipeline.from_pretrained(
        model_id,
        revision=revision,
        dtype=jnp.float32,
    )

    # Write to a private temp file and rename it into place, so an interrupted
    # run can never leave a truncated cache file that breaks the next start.
    tmp_file = f"{model_cache_file}.{os.getpid()}.tmp"
    try:
        with open(tmp_file, "wb") as f:
            pickle.dump((pipeline, params), f)
        os.replace(tmp_file, model_cache_file)
    except (OSError, pickle.PicklingError, TypeError, AttributeError) as err:
        # The cache is only an optimisation: report the failure and carry on.
        print(f"Could not write model cache ({err}); continuing without it.")
        try:
            os.remove(tmp_file)
        except OSError:
            pass
    return pipeline, params
# Fetch the pre-trained Stable Diffusion pipeline (revision "flax") via the
# cache-aware loader, then keep a handle on its UNet module, whose
# parameters are the ones optimised later in this script.
model_id = "CompVis/stable-diffusion-v1-4"
pipeline, params = get_model(model_id, revision="flax")
unet = pipeline.unet
# Load and preprocess your dataset
def preprocess_images(examples, resolution=512):
    """Convert a batch of images into normalised float32 arrays.

    Intended for ``Dataset.map(preprocess_images, batched=True)``.

    Args:
        examples: batch mapping whose ``"image"`` entry is a list of PIL
            images or filesystem paths to images.
        resolution: side length in pixels that every image is resized to
            (square output; aspect ratio is not preserved). Defaults to 512.

    Returns:
        ``{"pixel_values": [...]}`` with one ``(resolution, resolution, 3)``
        float32 array per image, rescaled from [0, 255] to [-1, 1].

    Raises:
        ValueError: if an entry is neither a path string nor a PIL image.
    """
    size = (resolution, resolution)

    def process_image(image):
        if isinstance(image, str):
            # Open paths in a context manager so the file handle is released;
            # convert() returns a new, fully loaded image we can keep.
            with Image.open(image) as opened:
                rgb = opened.convert("RGB")
        elif isinstance(image, Image.Image):
            rgb = image.convert("RGB")
        else:
            raise ValueError(f"Unexpected image type: {type(image)}")
        return np.asarray(rgb.resize(size), dtype=np.float32) / 127.5 - 1.0

    return {"pixel_values": [process_image(img) for img in examples["image"]]}
# Load dataset from Hugging Face
dataset_name = "uruguayai/montevideo"
dataset_cache_file = os.path.join(cache_dir, "montevideo_dataset.pkl")
print(f"Dataset name: {dataset_name}")
print(f"Dataset cache file: {dataset_cache_file}")


def _load_dataset_cache(path):
    """Return the pickled processed dataset at ``path``, or None on a cache miss.

    A missing, truncated or otherwise unreadable cache file is treated as a
    miss (and reported), so the dataset is rebuilt instead of aborting.
    """
    if not os.path.exists(path):
        return None
    print("Loading dataset from cache...")
    # SECURITY NOTE: pickle.load executes code embedded in the file. This
    # script's cache lives under /tmp; only trust it if nothing untrusted can write there.
    try:
        with open(path, "rb") as f:
            return pickle.load(f)
    except (OSError, EOFError, pickle.UnpicklingError, AttributeError, ImportError) as err:
        print(f"Ignoring unreadable dataset cache ({err}); rebuilding.")
        return None


def _save_dataset_cache(obj, path):
    """Pickle ``obj`` to ``path`` atomically, best effort.

    Writing to a temporary sibling and renaming it into place means an
    interrupted run never leaves a truncated cache behind. Failures are
    reported, not raised, because the cache is only an optimisation.
    """
    tmp_path = f"{path}.{os.getpid()}.tmp"
    try:
        with open(tmp_path, "wb") as f:
            pickle.dump(obj, f)
        os.replace(tmp_path, path)
    except (OSError, pickle.PicklingError, TypeError, AttributeError) as err:
        print(f"Could not write dataset cache ({err}); continuing without it.")
        try:
            os.remove(tmp_path)
        except OSError:
            pass


def _build_processed_dataset():
    """Download ``dataset_name`` and return its 'train' split run through preprocess_images."""
    print("Loading dataset from Hugging Face...")
    dataset = load_dataset(dataset_name)
    print("Dataset structure:", dataset)
    print("Available splits:", dataset.keys())
    if "train" not in dataset:
        raise ValueError("The dataset does not contain a 'train' split.")
    print("Processing dataset...")
    train_split = dataset["train"]
    return train_split.map(
        preprocess_images, batched=True, remove_columns=train_split.column_names
    )


def _print_dataset_search_diagnostics(max_depth=4):
    """Print directory listings that may help locate a local 'montevideo' directory.

    The filesystem walk skips /proc, /sys and /dev and stops descending
    after ``max_depth`` levels below '/': an unbounded ``os.walk('/')`` can
    take a very long time (or fail) inside a container.
    """
    for label, path in (("Current", "."), ("Parent", ".."), ("Root", "/")):
        print(f"{label} directory contents:")
        try:
            print(os.listdir(path))
        except OSError as err:
            print(f"<unreadable: {err}>")

    skipped = {"/proc", "/sys", "/dev"}
    for root, dirs, _files in os.walk("/"):
        dirs[:] = [d for d in dirs if os.path.join(root, d) not in skipped]
        if "montevideo" in dirs:
            found = os.path.join(root, "montevideo")
            print(f"Found 'montevideo' directory at: {found}")
            try:
                print(f"Contents: {os.listdir(found)}")
            except OSError as err:
                print(f"Contents: <unreadable: {err}>")
        depth = 0 if root == "/" else root.count("/")
        # Children of `root` sit at depth + 1; do not walk into them past max_depth.
        if depth + 1 >= max_depth:
            dirs[:] = []


try:
    processed_dataset = _load_dataset_cache(dataset_cache_file)
    if processed_dataset is None:
        processed_dataset = _build_processed_dataset()
        _save_dataset_cache(processed_dataset, dataset_cache_file)
    print(f"Processed dataset size: {len(processed_dataset)}")
except Exception as e:
    # Top-level boundary: report, print diagnostics, re-raise with the cause chained.
    print(f"Error loading or processing dataset: {str(e)}")
    print("Attempting to find dataset...")
    _print_dataset_search_diagnostics()
    raise ValueError(
        "Unable to locate or load the dataset. Please check the dataset path and permissions."
    ) from e
# Training function
def _encode_empty_prompt(batch_size):
    """Return CLIP text-encoder hidden states for ``batch_size`` empty prompts.

    The processed dataset keeps only ``pixel_values`` (no captions), so the
    UNet is conditioned on the empty prompt.
    """
    input_ids = pipeline.tokenizer(
        [""] * batch_size,
        padding="max_length",
        max_length=pipeline.tokenizer.model_max_length,
        truncation=True,
        return_tensors="np",
    ).input_ids
    return pipeline.text_encoder(input_ids, params=params["text_encoder"])[0]


def train_step(state, batch, rng):
    """Run one noise-prediction optimisation step on the UNet.

    Images are encoded to latents by the VAE, noised by the pipeline's
    scheduler at random timesteps, and the UNet (conditioned on empty-prompt
    text embeddings) is trained to predict that noise with an MSE loss.
    Only ``state.params`` (the UNet weights) are differentiated and updated;
    the VAE and text-encoder weights come from the module-level ``params``
    and are left unchanged.

    Args:
        state: ``TrainState`` whose ``apply_fn`` is the UNet module and whose
            ``params`` are the UNet parameters.
        batch: mapping with ``"pixel_values"`` holding images in [-1, 1] laid
            out as (batch, height, width, 3), as produced by preprocess_images.
        rng: PRNG key for this step.

    Returns:
        ``(new_state, loss)``.
    """
    # (B, H, W, C) -> (B, C, H, W): the Flax VAE takes channels-first input.
    pixel_values = jnp.asarray(np.asarray(batch["pixel_values"], dtype=np.float32))
    pixel_values = jnp.transpose(pixel_values, (0, 3, 1, 2))
    batch_size = pixel_values.shape[0]
    latent_rng, noise_rng, timestep_rng = jax.random.split(rng, 3)

    # Encode to latents outside the differentiated function: the VAE is not trained.
    vae_outputs = pipeline.vae.apply(
        {"params": params["vae"]},
        pixel_values,
        deterministic=True,
        method=pipeline.vae.encode,
    )
    latents = vae_outputs.latent_dist.sample(latent_rng)
    latents = jnp.transpose(latents, (0, 3, 1, 2))  # NHWC -> NCHW for the UNet
    latents = latents * pipeline.vae.config.scaling_factor

    noise = jax.random.normal(noise_rng, latents.shape)
    timesteps = jax.random.randint(
        timestep_rng, (batch_size,), 0, pipeline.scheduler.config.num_train_timesteps
    )
    # Flax schedulers take their state explicitly; the pipeline stores it in params.
    noisy_latents = pipeline.scheduler.add_noise(
        params["scheduler"], latents, noise, timesteps
    )
    encoder_hidden_states = _encode_empty_prompt(batch_size)

    def compute_loss(unet_params):
        # apply_fn holds the UNet module itself, hence .apply(...).
        model_output = state.apply_fn.apply(
            {"params": unet_params},
            noisy_latents,
            timesteps,
            encoder_hidden_states,
            train=True,
        ).sample
        return jnp.mean((model_output - noise) ** 2)

    loss, grads = jax.value_and_grad(compute_loss)(state.params)
    state = state.apply_gradients(grads=grads)
    return state, loss
# Build the optimiser state: plain Adam at a small fixed learning rate,
# applied to the UNet parameter subtree only.
learning_rate = 1e-5
optimizer = optax.adam(learning_rate=learning_rate)
# NOTE: apply_fn holds the UNet *module* (not its .apply method); the
# training step calls state.apply_fn.apply(...) accordingly.
state = train_state.TrainState.create(
    tx=optimizer,
    params=params["unet"],
    apply_fn=unet,
)
# Training loop
num_epochs = 10
batch_size = 4
rng = jax.random.PRNGKey(0)
if len(processed_dataset) == 0:
    # Fail loudly instead of dividing by zero when averaging the epoch loss.
    raise ValueError("The processed dataset is empty; nothing to train on.")
steps_per_epoch = -(-len(processed_dataset) // batch_size)  # ceiling division
for epoch in range(num_epochs):
    epoch_loss = 0.0
    num_batches = 0
    # Dataset.iter yields dict-of-lists batches lazily; Dataset.batch() would
    # construct a new batched Dataset on every epoch.
    for batch in tqdm(processed_dataset.iter(batch_size=batch_size), total=steps_per_epoch):
        rng, step_rng = jax.random.split(rng)
        state, loss = train_step(state, batch, step_rng)
        # float() copies the scalar loss to the host so the running sum is a plain float.
        epoch_loss += float(loss)
        num_batches += 1
    avg_loss = epoch_loss / num_batches
    print(f"Epoch {epoch+1}/{num_epochs}, Average Loss: {avg_loss}")
# Persist the trained UNet weights (state.params) with the diffusers
# save_pretrained API; the target directory is created first if missing.
output_dir = "/tmp/montevideo_fine_tuned_model"
os.makedirs(output_dir, exist_ok=True)
unet.save_pretrained(output_dir, params=state.params)
print("Model saved to", output_dir)