trainflux / app.py
uruguayai's picture
Update app.py
77248af verified
raw
history blame
3.99 kB
import jax
import jax.numpy as jnp
from flax.jax_utils import replicate
from flax.training import train_state
import optax
from diffusers import FlaxStableDiffusionPipeline
from datasets import load_dataset
from tqdm.auto import tqdm
import os
import pickle
from PIL import Image
import numpy as np
# Set up cache directories.
# All cached artifacts live under one root in /tmp; the model cache is a
# subdirectory of it, so a single makedirs call creates both levels.
cache_dir = "/tmp/huggingface_cache"
model_cache_dir = os.path.join(cache_dir, "stable_diffusion_model")
os.makedirs(model_cache_dir, exist_ok=True)
print(
    f"Cache directory: {cache_dir}\n"
    f"Model cache directory: {model_cache_dir}"
)
# Function to load or download the model
def _write_model_cache(obj, path):
    """Pickle *obj* to *path* atomically; report (don't raise) on failure.

    The data is written to a temporary file and then renamed over *path*, so an
    interrupted run cannot leave a truncated cache file behind.
    """
    tmp_path = f"{path}.tmp"
    try:
        with open(tmp_path, "wb") as f:
            pickle.dump(obj, f)
        os.replace(tmp_path, path)
    except Exception as err:  # pickling can fail with many exception types
        print(f"Could not write model cache ({err!r}); continuing without it.")
        try:
            os.remove(tmp_path)
        except OSError:
            pass  # temp file was never created or is already gone


def get_model(model_id, revision):
    """Return ``(pipeline, params)`` for a Flax Stable Diffusion checkpoint.

    The pair is pickled under ``model_cache_dir`` so later runs can skip the
    download. The cache is best-effort:
    - an unreadable cache file is ignored and the model is downloaded again;
    - a failure to write the cache is reported, and the freshly loaded model
      is still returned.

    Args:
        model_id: Hugging Face Hub repo id, e.g. ``"CompVis/stable-diffusion-v1-4"``.
        revision: Repo revision to load (this script passes ``"flax"``).

    Returns:
        A ``(pipeline, params)`` tuple.
    """
    model_cache_file = os.path.join(
        model_cache_dir, f"{model_id.replace('/', '_')}_{revision}.pkl"
    )
    print(f"Model cache file: {model_cache_file}")

    if os.path.exists(model_cache_file):
        print("Loading model from cache...")
        # SECURITY: pickle.load can execute arbitrary code stored in the file.
        # The cache lives under /tmp, which is usually world-writable, so only
        # rely on it where no untrusted user can write to the cache directory.
        try:
            with open(model_cache_file, "rb") as f:
                return pickle.load(f)
        except Exception as err:  # unpickling can raise nearly any exception type
            print(f"Ignoring unreadable model cache ({err!r}); re-downloading...")

    print("Downloading model...")
    # The Flax pipeline's from_pretrained returns the pipeline and its weights
    # as a separate (pipeline, params) tuple.
    pipeline, params = FlaxStableDiffusionPipeline.from_pretrained(
        model_id,
        revision=revision,
        dtype=jnp.float32,
    )
    _write_model_cache((pipeline, params), model_cache_file)
    return pipeline, params
# Load the pre-trained model (cached after the first run by get_model),
# using the repo's "flax" revision.
model_id = "CompVis/stable-diffusion-v1-4"
pipeline, params = get_model(model_id, revision="flax")

# Keep a direct handle to the pipeline's UNet component.
unet = pipeline.unet
# Load and preprocess your dataset
def preprocess_images(examples, size=512):
    """Convert a batch of images into normalized float32 pixel arrays.

    Intended for ``Dataset.map(..., batched=True)``: ``examples["image"]`` is a
    sequence whose items are PIL images or filesystem paths to image files.
    Each image is converted to RGB, resized to ``size`` x ``size`` pixels and
    scaled from ``[0, 255]`` to ``[-1.0, 1.0]``.

    Args:
        examples: Batch mapping containing an ``"image"`` column.
        size: Output width and height in pixels (default 512).

    Returns:
        ``{"pixel_values": [...]}`` with one ``(size, size, 3)`` float32 array
        per input image.

    Raises:
        ValueError: If an item is neither a path nor a ``PIL.Image.Image``.
    """

    def to_pixels(img):
        # convert() and resize() both force the pixel data to load, so the
        # result stays valid after a source file is closed.
        rgb = img.convert("RGB").resize((size, size))
        return np.array(rgb).astype(np.float32) / 127.5 - 1.0

    def process_image(image):
        if isinstance(image, (str, os.PathLike)):
            # Open in a context manager so the file handle is closed once the
            # pixels have been copied out.
            with Image.open(image) as opened:
                return to_pixels(opened)
        if not isinstance(image, Image.Image):
            raise ValueError(f"Unexpected image type: {type(image)}")
        return to_pixels(image)

    return {"pixel_values": [process_image(img) for img in examples["image"]]}
# Load dataset from Hugging Face
dataset_name = "uruguayai/montevideo"
dataset_cache_file = os.path.join(cache_dir, "montevideo_dataset.pkl")
print(f"Dataset name: {dataset_name}")
print(f"Dataset cache file: {dataset_cache_file}")


def _process_train_split(dataset, missing_split_message, announce=False):
    """Print *dataset*'s layout, then map ``preprocess_images`` over its "train" split.

    Args:
        dataset: Mapping of split name to dataset, as returned by ``load_dataset``.
        missing_split_message: Message for the ValueError raised when there is
            no "train" split.
        announce: When true, print "Processing dataset..." before mapping.

    Returns:
        The processed "train" split, with the original columns removed.

    Raises:
        ValueError: If *dataset* has no "train" split.
    """
    print("Dataset structure:", dataset)
    print("Available splits:", dataset.keys())
    if "train" not in dataset:
        raise ValueError(missing_split_message)
    if announce:
        print("Processing dataset...")
    train_split = dataset["train"]
    return train_split.map(
        preprocess_images, batched=True, remove_columns=train_split.column_names
    )


def _write_dataset_cache(ds, path):
    """Pickle *ds* to *path* atomically; report (don't raise) on failure.

    Writing to a temp file and renaming it means an interrupted run cannot
    leave a truncated cache behind. A failed write must not throw away a
    dataset that was already loaded and processed successfully.
    """
    tmp_path = f"{path}.tmp"
    try:
        with open(tmp_path, "wb") as f:
            pickle.dump(ds, f)
        os.replace(tmp_path, path)
    except Exception as err:  # pickling can fail with many exception types
        print(f"Could not write dataset cache ({err!r}); continuing without it.")
        try:
            os.remove(tmp_path)
        except OSError:
            pass  # temp file was never created or is already gone


try:
    processed_dataset = None
    if os.path.exists(dataset_cache_file):
        print("Loading dataset from cache...")
        # SECURITY: pickle.load can execute arbitrary code stored in the file.
        # The cache lives under /tmp, which is usually world-writable, so only
        # rely on it where no untrusted user can write to the cache directory.
        try:
            with open(dataset_cache_file, "rb") as f:
                processed_dataset = pickle.load(f)
        except Exception as err:  # unpickling can raise nearly any exception type
            print(f"Ignoring unreadable dataset cache ({err!r}); reloading...")
    if processed_dataset is None:
        print("Loading dataset from Hugging Face...")
        dataset = load_dataset(dataset_name)
        processed_dataset = _process_train_split(
            dataset,
            "The dataset does not contain a 'train' split.",
            announce=True,
        )
        _write_dataset_cache(processed_dataset, dataset_cache_file)
    print(f"Processed dataset size: {len(processed_dataset)}")
except Exception as e:
    print(f"Error loading or processing dataset: {str(e)}")
    print("Attempting to load from local path...")
    local_path = "/home/user/app/uruguayai/montevideo"
    if not os.path.exists(local_path):
        raise ValueError(f"Local path {local_path} does not exist.") from e
    print(f"Local path exists. Contents: {os.listdir(local_path)}")
    dataset = load_dataset("imagefolder", data_dir=local_path)
    processed_dataset = _process_train_split(
        dataset, "The local dataset does not contain a 'train' split."
    )
    print(f"Processed dataset size: {len(processed_dataset)}")
# Rest of your code (training loop, etc.) remains the same
...