File size: 7,816 Bytes
de0db89
 
3518b5f
de0db89
 
0b99dda
3518b5f
de0db89
 
 
 
 
 
35bc545
3518b5f
 
acc7f4b
 
 
 
de0db89
77248af
de0db89
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1f8900f
de0db89
 
3518b5f
de0db89
 
 
 
 
 
 
 
 
3518b5f
 
acc7f4b
 
967b314
 
 
66bb520
 
 
 
967b314
 
 
66bb520
967b314
7166f76
967b314
66bb520
967b314
 
 
16dd569
de0db89
 
 
 
7166f76
 
de0db89
 
7166f76
4434e29
920c999
5f8640f
de0db89
7166f76
 
de0db89
77248af
 
de0db89
 
77248af
de0db89
 
3518b5f
 
 
 
 
 
 
 
7166f76
3518b5f
 
649234d
3518b5f
de0db89
66bb520
 
ad3bf4c
 
 
 
 
66bb520
920c999
cc5a61c
 
8e214b7
5f8640f
ad3bf4c
8e214b7
 
 
 
 
 
 
ad3bf4c
8e214b7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
66bb520
 
 
 
dacaf33
e9745d9
8e214b7
76dfe67
dacaf33
c8658d7
8e214b7
dacaf33
8e214b7
 
 
 
 
 
 
 
 
 
 
 
 
 
5f8640f
8e214b7
e9745d9
8e214b7
 
 
 
 
 
 
9aba976
 
 
 
 
 
ad3bf4c
9aba976
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
import jax
import jax.numpy as jnp
from flax.jax_utils import replicate
from flax.training import train_state
import optax
from diffusers import FlaxStableDiffusionPipeline, FlaxUNet2DConditionModel
from diffusers.schedulers import FlaxPNDMScheduler
from datasets import load_dataset
from tqdm.auto import tqdm
import os
import pickle
from PIL import Image
import numpy as np

# Custom Scheduler
class CustomFlaxPNDMScheduler(FlaxPNDMScheduler):
    """PNDM scheduler whose ``add_noise`` casts timesteps to int32 first.

    All other behaviour is inherited unchanged from ``FlaxPNDMScheduler``.
    """

    def add_noise(self, state, original_samples, noise, timesteps):
        """Cast ``timesteps`` to int32, then delegate to the parent ``add_noise``.

        ``timesteps`` must support ``.astype`` (e.g. a jax or numpy array).
        """
        int32_timesteps = timesteps.astype(jnp.int32)
        # Forward positionally, in the parent's parameter order.
        return super().add_noise(state, original_samples, noise, int32_timesteps)

# On-disk cache root; the pickled model lives in a dedicated subdirectory.
cache_dir = "/tmp/huggingface_cache"
model_cache_dir = os.path.join(cache_dir, "stable_diffusion_model")
# exist_ok makes this a no-op on re-runs.
os.makedirs(name=model_cache_dir, exist_ok=True)

print(
    f"Cache directory: {cache_dir}"
)
print(
    f"Model cache directory: {model_cache_dir}"
)

# Function to load or download the model
def get_model(model_id, revision):
    """Return ``(pipeline, params)`` for ``model_id``/``revision``, via a pickle cache.

    On a cache hit the pickled tuple stored under ``model_cache_dir`` is loaded.
    On a miss, or when the cached file is truncated/corrupt, the pipeline is
    downloaded with ``FlaxStableDiffusionPipeline.from_pretrained`` (float32)
    and the cache file is (re)written.

    Args:
        model_id: Hub repository id, e.g. ``"CompVis/stable-diffusion-v1-4"``.
        revision: Branch/tag/commit to load (``"flax"`` in this script).

    Returns:
        Tuple ``(pipeline, params)`` as returned by ``from_pretrained``.
    """
    model_cache_file = os.path.join(model_cache_dir, f"{model_id.replace('/', '_')}_{revision}.pkl")
    print(f"Model cache file: {model_cache_file}")
    if os.path.exists(model_cache_file):
        print("Loading model from cache...")
        # SECURITY NOTE: pickle.load can execute arbitrary code stored in the
        # file. The cache sits under world-writable /tmp, so only rely on it on
        # a machine where no untrusted user can write there.
        try:
            with open(model_cache_file, 'rb') as f:
                return pickle.load(f)
        except (pickle.UnpicklingError, EOFError) as err:
            # A truncated cache (e.g. from an interrupted earlier run) must not
            # block every future start: fall through and re-download.
            print(f"Model cache unreadable ({err!r}); re-downloading...")

    print("Downloading model...")
    pipeline, params = FlaxStableDiffusionPipeline.from_pretrained(
        model_id,
        revision=revision,
        dtype=jnp.float32,
    )

    # Dump to a sibling temp file and atomically rename it into place, so a
    # crash or pickling error never leaves a partial file at the cache path.
    tmp_file = f"{model_cache_file}.tmp"
    try:
        with open(tmp_file, 'wb') as f:
            pickle.dump((pipeline, params), f)
        os.replace(tmp_file, model_cache_file)
    except (pickle.PicklingError, TypeError, AttributeError, OSError) as err:
        # Caching is only an optimisation; keep the downloaded model usable.
        print(f"Warning: could not cache model ({err!r}); continuing without cache.")
        if os.path.exists(tmp_file):
            os.remove(tmp_file)
    return pipeline, params

# Load the pre-trained Stable Diffusion pipeline (cached pickle when present).
model_id = "CompVis/stable-diffusion-v1-4"
pipeline, params = get_model(model_id, "flax")

# Replace the pipeline's scheduler with the int32-casting subclass, built
# from the existing scheduler's configuration.
pipeline.scheduler = custom_scheduler = CustomFlaxPNDMScheduler.from_config(
    pipeline.scheduler.config
)

# The UNet module is what the training loop below optimises and saves.
unet = pipeline.unet

print("UNet configuration:")
print(unet.config)

# Adjust the input layer of the UNet
def adjust_unet_input_layer(params, in_channels=4):
    """Make the UNet ``conv_in`` kernel accept ``in_channels`` input channels.

    The kernel at ``params['unet']['conv_in']['kernel']`` is indexed as
    ``(kh, kw, in, out)``. If its input-channel count already equals
    ``in_channels``, the pretrained kernel is kept untouched. Otherwise a zero
    kernel of shape ``(kh, kw, in_channels, out)`` with the original dtype is
    built. The overlapping input channels are copied into it, and any extra
    channels start at zero.

    Args:
        params: Nested parameter dict; modified in place.
        in_channels: Desired number of ``conv_in`` input channels. The default
            4 matches the channel count of the latents the UNet is fed.

    Returns:
        The same ``params`` object, with ``conv_in`` replaced only if needed.
    """
    kernel = params['unet']['conv_in']['kernel']
    print(f"Original conv_in weight shape: {kernel.shape}")
    kh, kw, old_in_channels, out_channels = kernel.shape
    # Rebuilding a kernel that already has the right shape would only discard
    # pretrained weights (the old code zeroed input channel 3 of a 4-channel
    # kernel), so leave it alone in that case.
    if old_in_channels != in_channels:
        keep = min(old_in_channels, in_channels)
        new_kernel = jnp.zeros((kh, kw, in_channels, out_channels), dtype=kernel.dtype)
        new_kernel = new_kernel.at[:, :, :keep, :].set(kernel[:, :, :keep, :])
        params['unet']['conv_in']['kernel'] = new_kernel
    print(f"New conv_in weight shape: {params['unet']['conv_in']['kernel'].shape}")
    return params

params = adjust_unet_input_layer(params=params)  # apply conv_in adjustment to the loaded params

# Load and preprocess your dataset
def preprocess_images(examples):
    """Batched ``datasets.map`` function that decodes images to float arrays.

    Each entry of ``examples["image"]`` may be a file path or a
    ``PIL.Image.Image``. Paths are accepted only for ``.jpg``/``.jpeg``
    (case-insensitive). Each accepted image is converted to RGB, resized to
    512x512, and returned as a ``float32`` array scaled to ``[0, 1]``. Any
    other entry is dropped, so the output can have fewer rows than the input.

    Returns:
        ``{"pixel_values": [array, ...]}``.
    """
    def process_image(image):
        if isinstance(image, str):
            if not image.lower().endswith(('.jpg', '.jpeg')):
                return None
            # Open inside a context manager so the file handle is released as
            # soon as the pixels are converted (convert() loads the data).
            with Image.open(image) as opened:
                rgb = opened.convert("RGB")
        elif isinstance(image, Image.Image):
            rgb = image.convert("RGB")
        else:
            return None
        resized = rgb.resize((512, 512))
        return np.asarray(resized, dtype=np.float32) / 255.0

    processed = (process_image(img) for img in examples["image"])
    return {"pixel_values": [img for img in processed if img is not None]}

# Fetch and preprocess the Hugging Face dataset, or reuse the pickled result.
dataset_name = "uruguayai/montevideo"
dataset_cache_file = os.path.join(cache_dir, "montevideo_dataset.pkl")

print(f"Dataset name: {dataset_name}")
print(f"Dataset cache file: {dataset_cache_file}")

if not os.path.exists(dataset_cache_file):
    print("Processing dataset...")
    dataset = load_dataset(dataset_name)
    train_split = dataset["train"]
    processed_dataset = train_split.map(
        preprocess_images,
        batched=True,
        remove_columns=train_split.column_names,
    ).filter(lambda example: len(example['pixel_values']) > 0)
    with open(dataset_cache_file, 'wb') as f:
        pickle.dump(processed_dataset, f)
else:
    print("Loading dataset from cache...")
    # NOTE: pickle.load executes code embedded in the file and /tmp is shared;
    # only trust this cache on a single-user machine.
    with open(dataset_cache_file, 'rb') as f:
        processed_dataset = pickle.load(f)

print(f"Processed dataset size: {len(processed_dataset)}")

# Peek at one single-example batch to sanity-check what the training loop sees.
sample_batch = next(iter(processed_dataset.batch(1)))
_sample_pixels = sample_batch['pixel_values']
print(f"Sample batch keys: {sample_batch.keys()}")
print(f"Sample pixel_values type: {type(_sample_pixels)}")
print(f"Sample pixel_values length: {len(_sample_pixels)}")
if len(_sample_pixels) > 0:
    _first_shape = np.array(_sample_pixels[0]).shape
    print(f"Sample pixel_values[0] shape: {_first_shape}")

# Training function
def train_step(state, batch, rng):
    """Run one optimisation step of the noise-prediction (epsilon) MSE loss.

    Args:
        state: ``TrainState`` whose ``params`` is the pipeline parameter dict
            (must contain ``"vae"`` and ``"unet"`` entries).
        batch: Mapping with ``"pixel_values"``. This is either one ``(H, W, 3)``
            image or an ``(N, H, W, 3)`` stack, with values in ``[0, 1]`` as
            produced by ``preprocess_images``.
        rng: ``jax.random`` key; split internally so no key is reused.

    Returns:
        ``(new_state, loss)`` where ``loss`` is the scalar MSE between the
        UNet output and the noise that was added to the latents.
    """
    def compute_loss(params, pixel_values, rng):
        pixel_values = jnp.asarray(pixel_values, dtype=jnp.float32)
        if pixel_values.ndim == 3:
            pixel_values = pixel_values[None]  # add batch dimension
        # Stable Diffusion's VAE works on images in [-1, 1]; the data is [0, 1].
        pixel_values = pixel_values * 2.0 - 1.0
        # diffusers' Flax VAE.encode and UNet take channels-first (NCHW) input.
        pixel_values = jnp.transpose(pixel_values, (0, 3, 1, 2))

        # One independent key per random draw. Reusing a single key correlates
        # them: with diffusers' sampler the "noise" would equal the Gaussian
        # draw used inside latent_dist.sample (same key, same shape).
        sample_rng, noise_rng, timestep_rng, context_rng = jax.random.split(rng, 4)

        latent_dist = pipeline.vae.apply(
            {"params": params["vae"]},
            pixel_values,
            method=pipeline.vae.encode,
        ).latent_dist
        latents = latent_dist.sample(sample_rng)
        # Sampled latents are channels-last; the UNet wants NCHW.
        latents = jnp.transpose(latents, (0, 3, 1, 2)) * jnp.float32(0.18215)
        # Only the UNet is fine-tuned (only its params are saved), so keep the
        # VAE frozen: without this the optimiser would also update the encoder.
        latents = jax.lax.stop_gradient(latents)

        batch_size = latents.shape[0]
        noise = jax.random.normal(noise_rng, latents.shape, dtype=jnp.float32)
        timesteps = jax.random.randint(
            timestep_rng, (batch_size,), 0, pipeline.scheduler.config.num_train_timesteps
        )
        noisy_latents = pipeline.scheduler.add_noise(
            pipeline.scheduler.create_state(),
            original_samples=latents,
            noise=noise,
            timesteps=timesteps,
        )

        # NOTE(review): conditioning is random noise rather than CLIP text
        # embeddings (kept from the original design; confirm this is intended).
        # Cross-attention needs a (batch, sequence, hidden) context, so a
        # sequence axis of the tokenizer's max length is included.
        encoder_hidden_states = jax.random.normal(
            context_rng,
            (
                batch_size,
                pipeline.tokenizer.model_max_length,
                pipeline.text_encoder.config.hidden_size,
            ),
            dtype=jnp.float32,
        )

        model_output = unet.apply(
            {'params': params["unet"]},
            noisy_latents,
            jnp.asarray(timesteps, dtype=jnp.int32),
            encoder_hidden_states,
            train=True,
        ).sample
        return jnp.mean((model_output - noise) ** 2)

    _, step_rng = jax.random.split(rng)
    # value_and_grad returns the loss from the same forward pass used for the
    # gradients instead of running the model a second time.
    grad_fn = jax.value_and_grad(compute_loss, argnums=0, allow_int=True)
    loss, grads = grad_fn(state.params, batch["pixel_values"], step_rng)
    state = state.apply_gradients(grads=grads)
    return state, loss

# Initialize training state
learning_rate = 1e-5
optimizer = optax.adam(learning_rate)


def _to_float32_if_floating(leaf):
    """Cast floating-point array leaves to float32; return any other leaf unchanged."""
    # Checking "is floating" (rather than "is not int32") keeps int64/uint8/bool
    # leaves intact instead of silently turning them into float32, and leaves
    # without a dtype are passed through rather than raising.
    dtype = getattr(leaf, "dtype", None)
    if dtype is not None and jnp.issubdtype(dtype, jnp.floating):
        return leaf.astype(jnp.float32)
    return leaf


float32_params = jax.tree_util.tree_map(_to_float32_if_floating, params)
state = train_state.TrainState.create(
    apply_fn=unet.apply,
    params=float32_params,
    tx=optimizer,
)

# Training loop
num_epochs = 3
batch_size = 1
rng = jax.random.PRNGKey(0)

for epoch in range(num_epochs):
    epoch_loss = 0.0
    num_steps = 0
    for batch in tqdm(processed_dataset.batch(batch_size)):
        # train_step consumes one (H, W, C) image per call (it adds the batch
        # axis itself). Step through every image so that batch_size > 1 does
        # not silently drop all but the first example.
        for image in batch['pixel_values']:
            rng, step_rng = jax.random.split(rng)
            step_batch = {'pixel_values': jnp.array(image, dtype=jnp.float32)}
            state, loss = train_step(state, step_batch, step_rng)
            # float() moves the scalar to the host, keeping the sum a plain number.
            epoch_loss += float(loss)
            num_steps += 1

            # NOTE(review): jax.clear_caches() drops compilation caches; it does
            # not free live arrays. Kept from the original — confirm it is needed.
            if num_steps % 10 == 0:
                jax.clear_caches()

    if num_steps == 0:
        # Avoid ZeroDivisionError when every image was filtered out.
        print(f"Epoch {epoch+1}/{num_epochs}: no training examples, nothing to average.")
    else:
        avg_loss = epoch_loss / num_steps
        print(f"Epoch {epoch+1}/{num_epochs}, Average Loss: {avg_loss}")
    jax.clear_caches()

# Persist only the fine-tuned UNet weights to a local directory.
output_dir = "/tmp/montevideo_fine_tuned_model"
os.makedirs(output_dir, exist_ok=True)
fine_tuned_unet_params = state.params["unet"]
unet.save_pretrained(output_dir, params=fine_tuned_unet_params)

print(f"Model saved to {output_dir}")