uruguayai committed
Commit 77248af · verified · 1 Parent(s): ec5a3dd

Update app.py

Files changed (1)
  1. app.py +45 -126
app.py CHANGED
@@ -12,7 +12,7 @@ from PIL import Image
 import numpy as np
 
 # Set up cache directories
-cache_dir = os.path.join(os.path.expanduser("~"), ".cache", "huggingface")
+cache_dir = "/tmp/huggingface_cache"
 model_cache_dir = os.path.join(cache_dir, "stable_diffusion_model")
 os.makedirs(model_cache_dir, exist_ok=True)
 
@@ -29,11 +29,12 @@ def get_model(model_id, revision):
             return pickle.load(f)
     else:
         print("Downloading model...")
-        pipeline, params = FlaxStableDiffusionPipeline.from_pretrained(
+        pipeline = FlaxStableDiffusionPipeline.from_pretrained(
            model_id,
            revision=revision,
            dtype=jnp.float32,
        )
+        params = pipeline.params
        with open(model_cache_file, 'wb') as f:
            pickle.dump((pipeline, params), f)
        return pipeline, params
@@ -42,25 +43,8 @@ def get_model(model_id, revision):
 model_id = "CompVis/stable-diffusion-v1-4"
 pipeline, params = get_model(model_id, "flax")
 
-# Extract UNet and its parameters
+# Extract UNet from pipeline
 unet = pipeline.unet
-unet_params = params["unet"]
-
-# Modify the conv_in layer to match the input shape
-input_channels = 3  # RGB images
-unet_params['conv_in']['kernel'] = jax.random.normal(
-    jax.random.PRNGKey(0),
-    (3, 3, input_channels, unet_params['conv_in']['kernel'].shape[-1])
-)
-
-# Initialize training state
-learning_rate = 1e-5
-optimizer = optax.adam(learning_rate)
-state = train_state.TrainState.create(
-    apply_fn=unet,
-    params=unet_params,
-    tx=optimizer,
-)
 
 # Load and preprocess your dataset
 def preprocess_images(examples):
@@ -69,119 +53,54 @@ def preprocess_images(examples):
             image = Image.open(image)
         if not isinstance(image, Image.Image):
             raise ValueError(f"Unexpected image type: {type(image)}")
-        # Ensure the image is in RGBA mode (4 channels)
-        image = image.convert("RGBA")
-        # Resize the image
-        image = image.resize((512, 512))
-        # Convert to numpy array and normalize
-        image_array = np.array(image).astype(np.float32) / 127.5 - 1.0
-        # Ensure the array has shape (height, width, 4)
-        return image_array
+        return np.array(image.convert("RGB").resize((512, 512))).astype(np.float32) / 127.5 - 1.0
 
     return {"pixel_values": [process_image(img) for img in examples["image"]]}
 
-# Load dataset with caching
-dataset_path = "uruguayai/montevideo"
+# Load dataset from Hugging Face
+dataset_name = "uruguayai/montevideo"
 dataset_cache_file = os.path.join(cache_dir, "montevideo_dataset.pkl")
 
-print(f"Dataset path: {dataset_path}")
+print(f"Dataset name: {dataset_name}")
 print(f"Dataset cache file: {dataset_cache_file}")
 
-if os.path.exists(dataset_cache_file):
-    print("Loading dataset from cache...")
-    with open(dataset_cache_file, 'rb') as f:
-        processed_dataset = pickle.load(f)
-else:
-    print("Processing dataset...")
-    dataset = load_dataset("imagefolder", data_dir=dataset_path)
-    processed_dataset = dataset["train"].map(preprocess_images, batched=True, remove_columns=dataset["train"].column_names)
-    with open(dataset_cache_file, 'wb') as f:
-        pickle.dump(processed_dataset, f)
-
-print(f"Processed dataset size: {len(processed_dataset)}")
-
-# Training function
-def train_step(state, batch, rng, scheduler, text_encoder):
-    def compute_loss(params):
-        # Convert batch to JAX array
-        pixel_values = jnp.array(batch["pixel_values"])
-        batch_size = pixel_values.shape[0]
-
-        # Reshape pixel_values to match the expected input shape (NCHW format)
-        pixel_values = jnp.transpose(pixel_values, (0, 3, 1, 2))  # NHWC to NCHW
-
-        # Generate random noise
-        noise_rng, timestep_rng = jax.random.split(rng)
-        noise = jax.random.normal(noise_rng, pixel_values.shape)
-
-        # Sample random timesteps
-        timesteps = jax.random.randint(
-            timestep_rng, (batch_size,), 0, scheduler.config.num_train_timesteps
-        )
-
-        # Generate noisy images
-        scheduler_state = scheduler.create_state()
-        noisy_images = scheduler.add_noise(scheduler_state, pixel_values, noise, timesteps)
-
-        # Generate random encoder_hidden_states (text embeddings)
-        encoder_hidden_states = jax.random.normal(
-            noise_rng, (batch_size, 77, 768)
-        )
-
-        # Print shapes for debugging
-        print("Input shape:", noisy_images.shape)
-        print("Conv_in kernel shape:", params['conv_in']['kernel'].shape)
-
-        # Predict noise
-        model_output = state.apply_fn.apply(
-            {'params': params},
-            jnp.array(noisy_images),
-            jnp.array(timesteps),
-            encoder_hidden_states=encoder_hidden_states,
-            train=True,
-        )
-
-        # Compute loss
-        loss = jnp.mean((model_output - noise) ** 2)
-        return loss
-
-    loss, grads = jax.value_and_grad(compute_loss)(state.params)
-    state = state.apply_gradients(grads=grads)
-    return state, loss
-
-
-
-# Initialize training state
-learning_rate = 1e-5
-optimizer = optax.adam(learning_rate)
-state = train_state.TrainState.create(
-    apply_fn=unet,
-    params=unet_params,
-    tx=optimizer,
-)
-
-# Training loop
-# Extract text encoder from pipeline
-text_encoder = pipeline.text_encoder
-
-# Training loop
-num_epochs = 10
-batch_size = 4
-rng = jax.random.PRNGKey(0)
-
-for epoch in range(num_epochs):
-    epoch_loss = 0
-    num_batches = 0
-    for batch in tqdm(processed_dataset.batch(batch_size)):
-        rng, step_rng = jax.random.split(rng)
-        state, loss = train_step(state, batch, step_rng, pipeline.scheduler, text_encoder)
-        epoch_loss += loss
-        num_batches += 1
-    avg_loss = epoch_loss / num_batches
-    print(f"Epoch {epoch+1}/{num_epochs}, Average Loss: {avg_loss}")
-
-# Save the fine-tuned model
-output_dir = "montevideo_fine_tuned_model"
-unet.save_pretrained(output_dir, params=state.params)
+try:
+    if os.path.exists(dataset_cache_file):
+        print("Loading dataset from cache...")
+        with open(dataset_cache_file, 'rb') as f:
+            processed_dataset = pickle.load(f)
+    else:
+        print("Loading dataset from Hugging Face...")
+        dataset = load_dataset(dataset_name)
+        print("Dataset structure:", dataset)
+        print("Available splits:", dataset.keys())
+
+        if "train" not in dataset:
+            raise ValueError("The dataset does not contain a 'train' split.")
+
+        print("Processing dataset...")
+        processed_dataset = dataset["train"].map(preprocess_images, batched=True, remove_columns=dataset["train"].column_names)
+        with open(dataset_cache_file, 'wb') as f:
+            pickle.dump(processed_dataset, f)
+
+    print(f"Processed dataset size: {len(processed_dataset)}")
+
+except Exception as e:
+    print(f"Error loading or processing dataset: {str(e)}")
+    print("Attempting to load from local path...")
+    local_path = "/home/user/app/uruguayai/montevideo"
+    if os.path.exists(local_path):
+        print(f"Local path exists. Contents: {os.listdir(local_path)}")
+        dataset = load_dataset("imagefolder", data_dir=local_path)
+        print("Dataset structure:", dataset)
+        print("Available splits:", dataset.keys())
+        if "train" in dataset:
+            processed_dataset = dataset["train"].map(preprocess_images, batched=True, remove_columns=dataset["train"].column_names)
+            print(f"Processed dataset size: {len(processed_dataset)}")
+        else:
+            raise ValueError("The local dataset does not contain a 'train' split.")
+    else:
+        raise ValueError(f"Local path {local_path} does not exist.")
 
-print(f"Model saved to {output_dir}")
+# Rest of your code (training loop, etc.) remains the same
+...
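Note on the preprocessing change: the new one-line process_image body folds the removed multi-step version into a single expression, and switches from RGBA (4 channels) to RGB (3 channels). A minimal standalone sketch of that transform, runnable with only Pillow and NumPy (the helper name normalize_like_process_image is illustrative, not part of the commit):

import numpy as np
from PIL import Image

def normalize_like_process_image(image):
    # Same transform as the new return line in process_image:
    # RGB, 512x512, float32, scaled from [0, 255] to [-1.0, 1.0].
    return np.array(image.convert("RGB").resize((512, 512))).astype(np.float32) / 127.5 - 1.0

# Quick check on a synthetic image.
demo = Image.new("RGB", (640, 480), color=(255, 128, 0))
arr = normalize_like_process_image(demo)
print(arr.shape)             # (512, 512, 3)
print(arr.min(), arr.max())  # -1.0 1.0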
 
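Both get_model and the new dataset block follow the same pickle-on-disk caching pattern: load the pickled result when the cache file exists, otherwise build it and write it back. A generic sketch of that pattern (the names load_or_build and build are illustrative, not part of the commit):

import os
import pickle

def load_or_build(cache_path, build):
    # Return the previously pickled result if the cache file exists;
    # otherwise build the result, pickle it, and return it.
    if os.path.exists(cache_path):
        with open(cache_path, 'rb') as f:
            return pickle.load(f)
    result = build()
    os.makedirs(os.path.dirname(cache_path) or ".", exist_ok=True)
    with open(cache_path, 'wb') as f:
        pickle.dump(result, f)
    return result

In the diff, get_model instantiates this pattern keyed on model_cache_file, and the dataset try-block instantiates it keyed on dataset_cache_file, with the imagefolder loader as a fallback path.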