import os
import pickle

import jax
import jax.numpy as jnp
import numpy as np
import optax
from datasets import load_dataset
from diffusers import FlaxStableDiffusionPipeline
from flax.training import train_state
from PIL import Image
from tqdm.auto import tqdm
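
# Local cache so repeated runs skip the model download and the dataset
# preprocessing below.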
cache_dir = os.path.join(os.path.expanduser("~"), ".cache", "huggingface")
model_cache_dir = os.path.join(cache_dir, "stable_diffusion_model")
os.makedirs(model_cache_dir, exist_ok=True)

print(f"Cache directory: {cache_dir}")
print(f"Model cache directory: {model_cache_dir}")
def get_model(model_id, revision):
    model_cache_file = os.path.join(model_cache_dir, f"{model_id.replace('/', '_')}_{revision}.pkl")
    print(f"Model cache file: {model_cache_file}")
    if os.path.exists(model_cache_file):
        print("Loading model from cache...")
        with open(model_cache_file, 'rb') as f:
            return pickle.load(f)
    else:
        print("Downloading model...")
        pipeline, params = FlaxStableDiffusionPipeline.from_pretrained(
            model_id,
            revision=revision,
            dtype=jnp.float32,
        )
        with open(model_cache_file, 'wb') as f:
            pickle.dump((pipeline, params), f)
        return pipeline, params


model_id = "CompVis/stable-diffusion-v1-4"
pipeline, params = get_model(model_id, "flax")

unet = pipeline.unet
unet_params = params["unet"]
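
# Swap the UNet's first convolution so it accepts 3-channel RGB images
# directly instead of 4-channel VAE latents: this script trains in pixel
# space and never touches the VAE. Flax conv kernels are shaped
# (height, width, in_channels, out_channels); the replacement kernel is
# standard-normal initialized, so this layer effectively starts from scratch.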
input_channels = 3
unet_params['conv_in']['kernel'] = jax.random.normal(
    jax.random.PRNGKey(0),
    (3, 3, input_channels, unet_params['conv_in']['kernel'].shape[-1])
)
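
# Bundle the params with an Adam optimizer. Note apply_fn is the module
# itself (not unet.apply), so train_step calls state.apply_fn.apply(...).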
learning_rate = 1e-5
optimizer = optax.adam(learning_rate)
state = train_state.TrainState.create(
    apply_fn=unet,
    params=unet_params,
    tx=optimizer,
)
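

# Resize every image to 512x512 RGB and scale pixels to [-1, 1], the input
# range the diffusion model expects.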
def preprocess_images(examples):
    def process_image(image):
        if isinstance(image, str):
            image = Image.open(image)
        if not isinstance(image, Image.Image):
            raise ValueError(f"Unexpected image type: {type(image)}")

        # RGB (3 channels) to match the rebuilt conv_in layer above; the
        # original converted to RGBA, whose 4th channel the new kernel
        # cannot accept.
        image = image.convert("RGB")
        image = image.resize((512, 512))
        image_array = np.array(image).astype(np.float32) / 127.5 - 1.0
        return image_array

    return {"pixel_values": [process_image(img) for img in examples["image"]]}


dataset_path = "C:/Users/Admin/Downloads/Montevideo/Output"
dataset_cache_file = os.path.join(cache_dir, "montevideo_dataset.pkl")

print(f"Dataset path: {dataset_path}")
print(f"Dataset cache file: {dataset_cache_file}")

if os.path.exists(dataset_cache_file):
    print("Loading dataset from cache...")
    with open(dataset_cache_file, 'rb') as f:
        processed_dataset = pickle.load(f)
else:
    print("Processing dataset...")
    dataset = load_dataset("imagefolder", data_dir=dataset_path)
    processed_dataset = dataset["train"].map(
        preprocess_images, batched=True, remove_columns=dataset["train"].column_names
    )
    with open(dataset_cache_file, 'wb') as f:
        pickle.dump(processed_dataset, f)

print(f"Processed dataset size: {len(processed_dataset)}")
def train_step(state, batch, rng, scheduler, text_encoder):
    def compute_loss(params):
        pixel_values = jnp.array(batch["pixel_values"])
        batch_size = pixel_values.shape[0]

        # NHWC -> NCHW, the layout the diffusers Flax UNet expects.
        pixel_values = jnp.transpose(pixel_values, (0, 3, 1, 2))

        # Independent keys for the noise, the timesteps, and the placeholder
        # text embeddings (the original reused one key twice).
        noise_rng, timestep_rng, hidden_rng = jax.random.split(rng, 3)
        noise = jax.random.normal(noise_rng, pixel_values.shape)

        timesteps = jax.random.randint(
            timestep_rng, (batch_size,), 0, scheduler.config.num_train_timesteps
        )

        scheduler_state = scheduler.create_state()
        noisy_images = scheduler.add_noise(scheduler_state, pixel_values, noise, timesteps)

        # Random stand-in for CLIP text embeddings: (batch, seq_len 77, dim 768).
        encoder_hidden_states = jax.random.normal(
            hidden_rng, (batch_size, 77, 768)
        )

        print("Input shape:", noisy_images.shape)
        print("Conv_in kernel shape:", params['conv_in']['kernel'].shape)

        model_output = state.apply_fn.apply(
            {'params': params},
            noisy_images,
            timesteps,
            encoder_hidden_states=encoder_hidden_states,
            train=True,
        )

        # The Flax UNet returns an output object; the noise prediction is
        # its .sample field.
        loss = jnp.mean((model_output.sample - noise) ** 2)
        return loss

    loss, grads = jax.value_and_grad(compute_loss)(state.params)
    state = state.apply_gradients(grads=grads)
    return state, loss


text_encoder = pipeline.text_encoder

num_epochs = 10
batch_size = 4
rng = jax.random.PRNGKey(0)
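
# Eager training loop; train_step is not jit-compiled. Wrapping it in
# jax.jit would likely speed this up, though the debug prints inside
# compute_loss would then only fire during tracing.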
for epoch in range(num_epochs):
    epoch_loss = 0
    num_batches = 0
    for batch in tqdm(processed_dataset.batch(batch_size)):
        rng, step_rng = jax.random.split(rng)
        state, loss = train_step(state, batch, step_rng, pipeline.scheduler, text_encoder)
        epoch_loss += loss
        num_batches += 1
    avg_loss = epoch_loss / num_batches
    print(f"Epoch {epoch+1}/{num_epochs}, Average Loss: {avg_loss}")
output_dir = "montevideo_fine_tuned_model"
unet.save_pretrained(output_dir, params=state.params)

print(f"Model saved to {output_dir}")