import jax
import jax.numpy as jnp
from flax.jax_utils import replicate
from flax.training import train_state
import optax
from diffusers import FlaxStableDiffusionPipeline
from datasets import load_dataset
from tqdm.auto import tqdm
import os
import pickle
from PIL import Image
import numpy as np

# Set up cache directories: a Hugging Face cache root in the user's home
# directory, plus a sub-directory that holds the pickled model pipelines.
cache_dir = os.path.join(
    os.path.expanduser("~"),
    ".cache",
    "huggingface",
)
model_cache_dir = os.path.join(cache_dir, "stable_diffusion_model")
os.makedirs(model_cache_dir, exist_ok=True)

# print() joins its arguments with a single space, giving "Label: path".
print("Cache directory:", cache_dir)
print("Model cache directory:", model_cache_dir)

# Function to load or download the model
def get_model(model_id, revision):
    """Return ``(pipeline, params)`` for a Flax Stable Diffusion checkpoint.

    The first call downloads the checkpoint with
    ``FlaxStableDiffusionPipeline.from_pretrained`` and pickles the result into
    ``model_cache_dir``. Later calls load that pickle instead.

    Args:
        model_id: Hugging Face Hub repository id, e.g.
            ``"CompVis/stable-diffusion-v1-4"``.
        revision: Revision passed to ``from_pretrained`` (e.g. ``"flax"``).

    Returns:
        The ``(pipeline, params)`` tuple produced by ``from_pretrained``.
    """
    model_cache_file = os.path.join(model_cache_dir, f"{model_id.replace('/', '_')}_{revision}.pkl")
    print(f"Model cache file: {model_cache_file}")

    if os.path.exists(model_cache_file):
        print("Loading model from cache...")
        # SECURITY NOTE: pickle.load executes arbitrary code; only load cache
        # files written by this script.
        try:
            with open(model_cache_file, 'rb') as f:
                return pickle.load(f)
        except (EOFError, pickle.UnpicklingError, AttributeError, ImportError) as err:
            # A truncated file (interrupted earlier run) or a pickle that no
            # longer matches the installed library versions: re-download
            # instead of failing on every run.
            print(f"Model cache is unreadable ({err}); re-downloading...")

    print("Downloading model...")
    pipeline, params = FlaxStableDiffusionPipeline.from_pretrained(
        model_id,
        revision=revision,
        dtype=jnp.float32,
    )
    # Dump into a temporary file, then atomically move it into place, so an
    # interrupted write never leaves a truncated cache file behind.
    tmp_cache_file = model_cache_file + ".tmp"
    with open(tmp_cache_file, 'wb') as f:
        pickle.dump((pipeline, params), f)
    os.replace(tmp_cache_file, model_cache_file)
    return pipeline, params

# Load the pre-trained model
model_id = "CompVis/stable-diffusion-v1-4"
pipeline, params = get_model(model_id, "flax")

# Extract UNet and its parameters
unet = pipeline.unet
unet_params = params["unet"]

# Channel count of the training images: preprocess_images converts every image
# to RGBA, so each pixel has 4 channels.
input_channels = 4

# The loss compares the UNet's prediction against noise shaped like the input
# images, so conv_out must emit as many channels as the images have.
output_channels = unet_params['conv_out']['kernel'].shape[-1]
if output_channels != input_channels:
    raise ValueError(
        f"UNet predicts {output_channels} channels but the images have "
        f"{input_channels}; the noise-prediction loss needs them to match."
    )

# Flax conv kernels are laid out as (kh, kw, in_features, out_features).
# Only replace the pretrained conv_in kernel when its input-channel count
# differs from the data; re-initializing a kernel that already matches would
# just throw away pretrained weights.
conv_in_kernel = unet_params['conv_in']['kernel']
if conv_in_kernel.shape[2] != input_channels:
    unet_params['conv_in']['kernel'] = jax.random.normal(
        jax.random.PRNGKey(0),
        (*conv_in_kernel.shape[:2], input_channels, conv_in_kernel.shape[-1]),
    )

# Initialize training state: Adam with a constant learning rate over the UNet
# parameters prepared above.
learning_rate = 1e-5
optimizer = optax.adam(learning_rate=learning_rate)
state = train_state.TrainState.create(
    tx=optimizer,
    params=unet_params,
    # The Flax module itself (not its bound .apply); train_step invokes
    # `state.apply_fn.apply(...)`.
    apply_fn=unet,
)

# Load and preprocess your dataset
def preprocess_images(examples):
    """Convert a batch of images into normalized RGBA pixel arrays.

    Written for ``datasets.Dataset.map(..., batched=True)``.

    Args:
        examples: Batch mapping whose ``"image"`` entry is a list of
            ``PIL.Image.Image`` objects or image file paths
            (``str`` / ``os.PathLike``).

    Returns:
        ``{"pixel_values": [...]}``. Each element is a ``float32`` array of
        shape ``(512, 512, 4)`` with values in ``[-1, 1]``.

    Raises:
        ValueError: If an entry is neither a PIL image nor a path.
    """
    def process_image(image):
        if isinstance(image, (str, os.PathLike)):
            # Open through a context manager so the file handle is closed as
            # soon as convert() has decoded the pixels into a new image.
            with Image.open(image) as opened:
                image = opened.convert("RGBA")
        elif isinstance(image, Image.Image):
            # Force RGBA (4 channels) so every image has the same layout.
            image = image.convert("RGBA")
        else:
            raise ValueError(f"Unexpected image type: {type(image)}")
        image = image.resize((512, 512))
        # uint8 [0, 255] -> float32 [-1, 1]; RGBA gives shape (512, 512, 4).
        return np.asarray(image, dtype=np.float32) / 127.5 - 1.0

    return {"pixel_values": [process_image(img) for img in examples["image"]]}

# Load dataset with caching
dataset_path = "C:/Users/Admin/Downloads/Montevideo/Output"
dataset_cache_file = os.path.join(cache_dir, "montevideo_dataset.pkl")

print(f"Dataset path: {dataset_path}")
print(f"Dataset cache file: {dataset_cache_file}")

# NOTE: the cache is keyed only by its file name. Delete it after changing
# preprocess_images or the images under dataset_path; otherwise the stale
# preprocessed data is silently reused.
processed_dataset = None
if os.path.exists(dataset_cache_file):
    print("Loading dataset from cache...")
    try:
        with open(dataset_cache_file, 'rb') as f:
            processed_dataset = pickle.load(f)
    except (EOFError, pickle.UnpicklingError, OSError) as err:
        # Truncated/corrupt cache, or (presumably) a pickle referring to Arrow
        # cache files that no longer exist: rebuild instead of failing.
        print(f"Dataset cache is unreadable ({err}); reprocessing...")

if processed_dataset is None:
    print("Processing dataset...")
    dataset = load_dataset("imagefolder", data_dir=dataset_path)
    train_split = dataset["train"]
    processed_dataset = train_split.map(
        preprocess_images, batched=True, remove_columns=train_split.column_names
    )
    # NOTE(review): pickling a `datasets.Dataset` may store references to its
    # on-disk Arrow cache rather than the data itself; `save_to_disk` /
    # `load_from_disk` may be more robust. Confirm before relying on this cache.
    # Write to a temporary file, then atomically rename, so an interrupted
    # dump never leaves a truncated cache behind.
    tmp_cache_file = dataset_cache_file + ".tmp"
    with open(tmp_cache_file, 'wb') as f:
        pickle.dump(processed_dataset, f)
    os.replace(tmp_cache_file, dataset_cache_file)

print(f"Processed dataset size: {len(processed_dataset)}")

# Training function
def train_step(state, batch, rng, scheduler, text_encoder):
    """Run one noise-prediction optimisation step on a batch of images.

    Args:
        state: Training state whose ``apply_fn`` is the UNet *module* (its
            ``.apply`` is called below) and whose ``params`` are the UNet
            parameters being trained. It must provide ``apply_gradients``.
        batch: Mapping whose ``"pixel_values"`` entry is channels-last,
            ``(batch, height, width, channels)``.
        rng: JAX PRNG key. It is split into independent keys for the noise,
            the timesteps and the placeholder text embeddings.
        scheduler: Flax noise scheduler providing
            ``config.num_train_timesteps``, ``create_state()`` and
            ``add_noise(state, samples, noise, timesteps)``.
        text_encoder: Unused; kept for interface compatibility.
            NOTE(review): conditioning uses random embeddings, not real text.

    Returns:
        ``(new_state, loss)``. ``loss`` is the mean-squared error between the
        predicted noise and the noise that was added.
    """
    # The Flax UNet takes channels-first (NCHW) samples, so transpose from NHWC.
    pixel_values = jnp.transpose(jnp.asarray(batch["pixel_values"]), (0, 3, 1, 2))
    batch_size = pixel_values.shape[0]

    # Independent keys so the noise, timesteps and text embeddings are not
    # derived from the same random stream.
    noise_rng, timestep_rng, text_rng = jax.random.split(rng, 3)
    noise = jax.random.normal(noise_rng, pixel_values.shape)
    timesteps = jax.random.randint(
        timestep_rng, (batch_size,), 0, scheduler.config.num_train_timesteps
    )

    # None of these depend on the trainable params, so compute them once,
    # outside the differentiated function.
    scheduler_state = scheduler.create_state()
    noisy_images = scheduler.add_noise(scheduler_state, pixel_values, noise, timesteps)
    # Placeholder conditioning: random (batch, 77, 768) "text" embeddings.
    encoder_hidden_states = jax.random.normal(text_rng, (batch_size, 77, 768))

    def compute_loss(params):
        model_output = state.apply_fn.apply(
            {'params': params},
            noisy_images,
            timesteps,
            encoder_hidden_states=encoder_hidden_states,
            train=True,
        )
        # The UNet returns an output dataclass; the prediction is `.sample`.
        predicted_noise = model_output.sample
        return jnp.mean((predicted_noise - noise) ** 2)

    loss, grads = jax.value_and_grad(compute_loss)(state.params)
    state = state.apply_gradients(grads=grads)
    return state, loss



# The training state (Adam, learning rate 1e-5, over `unet_params`) is created
# once, earlier in this script, right after the UNet parameters are prepared.
# It is deliberately not re-created here: an identical re-initialization would
# only repeat the same optimizer setup.

# Extract text encoder from pipeline. It is passed through to train_step but
# not used there for conditioning.
text_encoder = pipeline.text_encoder

# Training loop
num_epochs = 10
batch_size = 4
rng = jax.random.PRNGKey(0)

for epoch in range(num_epochs):
    epoch_loss = 0.0
    num_batches = 0
    for batch in tqdm(processed_dataset.batch(batch_size)):
        rng, step_rng = jax.random.split(rng)
        state, loss = train_step(state, batch, step_rng, pipeline.scheduler, text_encoder)
        # Accumulate a plain Python float rather than a chain of JAX arrays.
        epoch_loss += float(loss)
        num_batches += 1
    if num_batches == 0:
        # Fail with a clear message instead of a ZeroDivisionError below.
        raise ValueError("Processed dataset is empty; nothing to train on.")
    avg_loss = epoch_loss / num_batches
    print(f"Epoch {epoch+1}/{num_epochs}, Average Loss: {avg_loss}")

# Save the fine-tuned model
output_dir = "montevideo_fine_tuned_model"
unet.save_pretrained(output_dir, params=state.params)

print(f"Model saved to {output_dir}")