uruguayai commited on
Commit
de0db89
·
verified ·
1 Parent(s): 77d8758

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +187 -0
app.py ADDED
@@ -0,0 +1,187 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import jax
2
+ import jax.numpy as jnp
3
+ from flax.jax_utils import replicate
4
+ from flax.training import train_state
5
+ import optax
6
+ from diffusers import FlaxStableDiffusionPipeline
7
+ from datasets import load_dataset
8
+ from tqdm.auto import tqdm
9
+ import os
10
+ import pickle
11
+ from PIL import Image
12
+ import numpy as np
13
+
14
+ # Set up cache directories
15
+ cache_dir = os.path.join(os.path.expanduser("~"), ".cache", "huggingface")
16
+ model_cache_dir = os.path.join(cache_dir, "stable_diffusion_model")
17
+ os.makedirs(model_cache_dir, exist_ok=True)
18
+
19
+ print(f"Cache directory: {cache_dir}")
20
+ print(f"Model cache directory: {model_cache_dir}")
21
+
22
+ # Function to load or download the model
23
+ def get_model(model_id, revision):
24
+ model_cache_file = os.path.join(model_cache_dir, f"{model_id.replace('/', '_')}_{revision}.pkl")
25
+ print(f"Model cache file: {model_cache_file}")
26
+ if os.path.exists(model_cache_file):
27
+ print("Loading model from cache...")
28
+ with open(model_cache_file, 'rb') as f:
29
+ return pickle.load(f)
30
+ else:
31
+ print("Downloading model...")
32
+ pipeline, params = FlaxStableDiffusionPipeline.from_pretrained(
33
+ model_id,
34
+ revision=revision,
35
+ dtype=jnp.float32,
36
+ )
37
+ with open(model_cache_file, 'wb') as f:
38
+ pickle.dump((pipeline, params), f)
39
+ return pipeline, params
40
+
41
+ # Load the pre-trained model
42
+ model_id = "CompVis/stable-diffusion-v1-4"
43
+ pipeline, params = get_model(model_id, "flax")
44
+
45
+ # Extract UNet and its parameters
46
+ unet = pipeline.unet
47
+ unet_params = params["unet"]
48
+
49
+ # Modify the conv_in layer to match the input shape
50
+ input_channels = 3 # RGB images
51
+ unet_params['conv_in']['kernel'] = jax.random.normal(
52
+ jax.random.PRNGKey(0),
53
+ (3, 3, input_channels, unet_params['conv_in']['kernel'].shape[-1])
54
+ )
55
+
56
+ # Initialize training state
57
+ learning_rate = 1e-5
58
+ optimizer = optax.adam(learning_rate)
59
+ state = train_state.TrainState.create(
60
+ apply_fn=unet,
61
+ params=unet_params,
62
+ tx=optimizer,
63
+ )
64
+
65
+ # Load and preprocess your dataset
66
+ def preprocess_images(examples):
67
+ def process_image(image):
68
+ if isinstance(image, str):
69
+ image = Image.open(image)
70
+ if not isinstance(image, Image.Image):
71
+ raise ValueError(f"Unexpected image type: {type(image)}")
72
+ # Ensure the image is in RGBA mode (4 channels)
73
+ image = image.convert("RGBA")
74
+ # Resize the image
75
+ image = image.resize((512, 512))
76
+ # Convert to numpy array and normalize
77
+ image_array = np.array(image).astype(np.float32) / 127.5 - 1.0
78
+ # Ensure the array has shape (height, width, 4)
79
+ return image_array
80
+
81
+ return {"pixel_values": [process_image(img) for img in examples["image"]]}
82
+
83
+ # Load dataset with caching
84
+ dataset_path = "C:/Users/Admin/Downloads/Montevideo/Output"
85
+ dataset_cache_file = os.path.join(cache_dir, "montevideo_dataset.pkl")
86
+
87
+ print(f"Dataset path: {dataset_path}")
88
+ print(f"Dataset cache file: {dataset_cache_file}")
89
+
90
+ if os.path.exists(dataset_cache_file):
91
+ print("Loading dataset from cache...")
92
+ with open(dataset_cache_file, 'rb') as f:
93
+ processed_dataset = pickle.load(f)
94
+ else:
95
+ print("Processing dataset...")
96
+ dataset = load_dataset("imagefolder", data_dir=dataset_path)
97
+ processed_dataset = dataset["train"].map(preprocess_images, batched=True, remove_columns=dataset["train"].column_names)
98
+ with open(dataset_cache_file, 'wb') as f:
99
+ pickle.dump(processed_dataset, f)
100
+
101
+ print(f"Processed dataset size: {len(processed_dataset)}")
102
+
103
+ # Training function
104
+ def train_step(state, batch, rng, scheduler, text_encoder):
105
+ def compute_loss(params):
106
+ # Convert batch to JAX array
107
+ pixel_values = jnp.array(batch["pixel_values"])
108
+ batch_size = pixel_values.shape[0]
109
+
110
+ # Reshape pixel_values to match the expected input shape (NCHW format)
111
+ pixel_values = jnp.transpose(pixel_values, (0, 3, 1, 2)) # NHWC to NCHW
112
+
113
+ # Generate random noise
114
+ noise_rng, timestep_rng = jax.random.split(rng)
115
+ noise = jax.random.normal(noise_rng, pixel_values.shape)
116
+
117
+ # Sample random timesteps
118
+ timesteps = jax.random.randint(
119
+ timestep_rng, (batch_size,), 0, scheduler.config.num_train_timesteps
120
+ )
121
+
122
+ # Generate noisy images
123
+ scheduler_state = scheduler.create_state()
124
+ noisy_images = scheduler.add_noise(scheduler_state, pixel_values, noise, timesteps)
125
+
126
+ # Generate random encoder_hidden_states (text embeddings)
127
+ encoder_hidden_states = jax.random.normal(
128
+ noise_rng, (batch_size, 77, 768)
129
+ )
130
+
131
+ # Print shapes for debugging
132
+ print("Input shape:", noisy_images.shape)
133
+ print("Conv_in kernel shape:", params['conv_in']['kernel'].shape)
134
+
135
+ # Predict noise
136
+ model_output = state.apply_fn.apply(
137
+ {'params': params},
138
+ jnp.array(noisy_images),
139
+ jnp.array(timesteps),
140
+ encoder_hidden_states=encoder_hidden_states,
141
+ train=True,
142
+ )
143
+
144
+ # Compute loss
145
+ loss = jnp.mean((model_output - noise) ** 2)
146
+ return loss
147
+
148
+ loss, grads = jax.value_and_grad(compute_loss)(state.params)
149
+ state = state.apply_gradients(grads=grads)
150
+ return state, loss
151
+
152
+
153
+
154
+ # Initialize training state
155
+ learning_rate = 1e-5
156
+ optimizer = optax.adam(learning_rate)
157
+ state = train_state.TrainState.create(
158
+ apply_fn=unet,
159
+ params=unet_params,
160
+ tx=optimizer,
161
+ )
162
+
163
+ # Training loop
164
+ # Extract text encoder from pipeline
165
+ text_encoder = pipeline.text_encoder
166
+
167
+ # Training loop
168
+ num_epochs = 10
169
+ batch_size = 4
170
+ rng = jax.random.PRNGKey(0)
171
+
172
+ for epoch in range(num_epochs):
173
+ epoch_loss = 0
174
+ num_batches = 0
175
+ for batch in tqdm(processed_dataset.batch(batch_size)):
176
+ rng, step_rng = jax.random.split(rng)
177
+ state, loss = train_step(state, batch, step_rng, pipeline.scheduler, text_encoder)
178
+ epoch_loss += loss
179
+ num_batches += 1
180
+ avg_loss = epoch_loss / num_batches
181
+ print(f"Epoch {epoch+1}/{num_epochs}, Average Loss: {avg_loss}")
182
+
183
+ # Save the fine-tuned model
184
+ output_dir = "montevideo_fine_tuned_model"
185
+ unet.save_pretrained(output_dir, params=state.params)
186
+
187
+ print(f"Model saved to {output_dir}")