import os
import pickle

import jax
import jax.numpy as jnp
import numpy as np
import optax
from datasets import load_dataset
from diffusers import FlaxStableDiffusionPipeline
from flax.jax_utils import replicate
from flax.training import train_state
from PIL import Image
from tqdm.auto import tqdm

cache_dir = "/tmp/huggingface_cache" |
|
model_cache_dir = os.path.join(cache_dir, "stable_diffusion_model") |
|
os.makedirs(model_cache_dir, exist_ok=True) |
|
|
|
print(f"Cache directory: {cache_dir}") |
|
print(f"Model cache directory: {model_cache_dir}") |
|
|
|
|
|
def get_model(model_id, revision):
    # Load the Flax Stable Diffusion pipeline, caching the (pipeline, params)
    # tuple on disk with pickle so repeated runs skip the download.
    model_cache_file = os.path.join(
        model_cache_dir, f"{model_id.replace('/', '_')}_{revision}.pkl"
    )
    print(f"Model cache file: {model_cache_file}")
    if os.path.exists(model_cache_file):
        print("Loading model from cache...")
        with open(model_cache_file, "rb") as f:
            return pickle.load(f)

    print("Downloading model...")
    # from_pretrained returns a (pipeline, params) tuple for Flax pipelines;
    # the parameters are not stored on the pipeline object itself.
    pipeline, params = FlaxStableDiffusionPipeline.from_pretrained(
        model_id,
        revision=revision,
        dtype=jnp.float32,
    )
    with open(model_cache_file, "wb") as f:
        pickle.dump((pipeline, params), f)
    return pipeline, params


model_id = "CompVis/stable-diffusion-v1-4"
pipeline, params = get_model(model_id, "flax")
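
# Optional sanity check (a hedged sketch: assumes `params` is the usual dict
# of per-module pytrees keyed by names such as "unet", "vae", "text_encoder").
param_counts = {
    name: sum(np.size(leaf) for leaf in jax.tree_util.tree_leaves(tree))
    for name, tree in params.items()
}
print(f"Parameter counts per module: {param_counts}")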

# The UNet is the denoising network whose parameters fine-tuning will update.
unet = pipeline.unet
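
# Hedged sketch of the training setup this builds toward (the learning rate
# and optimizer are assumed values for illustration, not the script's actual
# settings): wrap the UNet parameters in a Flax TrainState.
learning_rate = 1e-5  # assumed value, for illustration only
state = train_state.TrainState.create(
    apply_fn=unet.__call__,
    params=params["unet"],
    tx=optax.adamw(learning_rate),
)
# For pmap-style data parallelism the state would then be broadcast to all
# devices, e.g. replicated_state = replicate(state).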


def preprocess_images(examples):
    # Resize each image to 512x512 RGB and rescale pixel values from
    # [0, 255] to [-1, 1], the range the Stable Diffusion VAE expects.
    def process_image(image):
        if isinstance(image, str):
            image = Image.open(image)
        if not isinstance(image, Image.Image):
            raise ValueError(f"Unexpected image type: {type(image)}")
        image = image.convert("RGB").resize((512, 512))
        return np.array(image).astype(np.float32) / 127.5 - 1.0

    return {"pixel_values": [process_image(img) for img in examples["image"]]}


dataset_name = "uruguayai/montevideo"
dataset_cache_file = os.path.join(cache_dir, "montevideo_dataset.pkl")

print(f"Dataset name: {dataset_name}")
print(f"Dataset cache file: {dataset_cache_file}")

try:
    if os.path.exists(dataset_cache_file):
        print("Loading dataset from cache...")
        with open(dataset_cache_file, "rb") as f:
            processed_dataset = pickle.load(f)
    else:
        print("Loading dataset from Hugging Face...")
        dataset = load_dataset(dataset_name)
        print("Dataset structure:", dataset)
        print("Available splits:", dataset.keys())

        if "train" not in dataset:
            raise ValueError("The dataset does not contain a 'train' split.")

        print("Processing dataset...")
        processed_dataset = dataset["train"].map(
            preprocess_images,
            batched=True,
            remove_columns=dataset["train"].column_names,
        )
        with open(dataset_cache_file, "wb") as f:
            pickle.dump(processed_dataset, f)

    print(f"Processed dataset size: {len(processed_dataset)}")

except Exception as e:
    print(f"Error loading or processing dataset: {e}")
    print("Attempting to load from local path...")
    # Fallback: load the images straight from disk as an `imagefolder` dataset.
    local_path = "/home/user/app/uruguayai/montevideo"
    if not os.path.exists(local_path):
        raise ValueError(f"Local path {local_path} does not exist.")
    print(f"Local path exists. Contents: {os.listdir(local_path)}")
    dataset = load_dataset("imagefolder", data_dir=local_path)
    print("Dataset structure:", dataset)
    print("Available splits:", dataset.keys())
    if "train" not in dataset:
        raise ValueError("The local dataset does not contain a 'train' split.")
    processed_dataset = dataset["train"].map(
        preprocess_images,
        batched=True,
        remove_columns=dataset["train"].column_names,
    )
    print(f"Processed dataset size: {len(processed_dataset)}")

...