uruguayai committed
Commit 920c999 · verified · 1 parent: cc5a61c

Update app.py

Files changed (1)
  1. app.py (+17, -15)
app.py CHANGED
@@ -1,6 +1,5 @@
 import jax
 import jax.numpy as jnp
-from flax.jax_utils import replicate
 from flax.training import train_state
 import optax
 from diffusers import FlaxStableDiffusionPipeline
@@ -12,6 +11,11 @@ from PIL import Image
 import numpy as np
 import gc
 
+# Force JAX to use CPU
+jax.config.update('jax_platform_name', 'cpu')
+
+print("Using CPU for computations")
+
 # Set up cache directories
 cache_dir = "/tmp/huggingface_cache"
 model_cache_dir = os.path.join(cache_dir, "stable_diffusion_model")
@@ -33,7 +37,7 @@ def get_model(model_id, revision):
     pipeline, params = FlaxStableDiffusionPipeline.from_pretrained(
         model_id,
         revision=revision,
-        dtype=jnp.float16,
+        dtype=jnp.float32,  # Use float32 for CPU
     )
     with open(model_cache_file, 'wb') as f:
         pickle.dump((pipeline, params), f)
@@ -53,8 +57,8 @@ def preprocess_images(examples):
         image = Image.open(image)
         if not isinstance(image, Image.Image):
             raise ValueError(f"Unexpected image type: {type(image)}")
-        image = image.convert("RGB").resize((256, 256))  # Reduced image size
-        image = np.array(image).astype(np.float16) / 255.0
+        image = image.convert("RGB").resize((128, 128))  # Further reduced image size
+        image = np.array(image).astype(np.float32) / 255.0
         return image.transpose(2, 0, 1)
 
     return {"pixel_values": [process_image(img) for img in examples["image"]]}
@@ -73,7 +77,7 @@ try:
             processed_dataset = pickle.load(f)
     else:
         print("Loading dataset from Hugging Face...")
-        dataset = load_dataset(dataset_name, split="train[:1000]")  # Load only first 1000 samples
+        dataset = load_dataset(dataset_name, split="train[:500]")  # Load only first 500 samples
         print("Processing dataset...")
         processed_dataset = dataset.map(preprocess_images, batched=True, remove_columns=dataset.column_names)
         with open(dataset_cache_file, 'wb') as f:
@@ -90,8 +94,7 @@ def clear_jit_cache():
     jax.clear_caches()
     gc.collect()
 
-# Training function with gradient accumulation
-@jax.jit
+# Training function
 def train_step(state, batch, rng):
     def compute_loss(params, pixel_values, rng):
         latents = pipeline.vae.apply(
@@ -99,9 +102,9 @@ def train_step(state, batch, rng):
             pixel_values,
             method=pipeline.vae.encode
         ).latent_dist.sample(rng)
-        latents = latents * jnp.float16(0.18215)
+        latents = latents * 0.18215
 
-        noise = jax.random.normal(rng, latents.shape, dtype=jnp.float16)
+        noise = jax.random.normal(rng, latents.shape)
         timesteps = jax.random.randint(
             rng, (latents.shape[0],), 0, pipeline.scheduler.config.num_train_timesteps
         )
@@ -114,8 +117,7 @@ def train_step(state, batch, rng):
 
         encoder_hidden_states = jax.random.normal(
             rng,
-            (latents.shape[0], pipeline.text_encoder.config.hidden_size),
-            dtype=jnp.float16
+            (latents.shape[0], pipeline.text_encoder.config.hidden_size)
        )
 
        model_output = state.apply_fn.apply(
@@ -135,7 +137,7 @@ def train_step(state, batch, rng):
    return state, loss
 
 # Initialize training state
-learning_rate = jnp.float16(1e-5)
+learning_rate = 1e-5
 optimizer = optax.adam(learning_rate)
 state = train_state.TrainState.create(
     apply_fn=unet,
@@ -144,15 +146,15 @@ state = train_state.TrainState.create(
 )
 
 # Training loop
-num_epochs = 5  # Reduced number of epochs
-batch_size = 4
+num_epochs = 3  # Further reduced number of epochs
+batch_size = 2  # Reduced batch size for CPU
 rng = jax.random.PRNGKey(0)
 
 for epoch in range(num_epochs):
     epoch_loss = 0
     num_batches = 0
     for batch in tqdm(processed_dataset.batch(batch_size)):
-        batch['pixel_values'] = jnp.array(batch['pixel_values'], dtype=jnp.float16)
+        batch['pixel_values'] = jnp.array(batch['pixel_values'])
         rng, step_rng = jax.random.split(rng)
         state, loss = train_step(state, batch, step_rng)
         epoch_loss += loss
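The two CPU-related changes above (forcing the CPU platform and moving from float16 to float32) can be sanity-checked on their own. A minimal sketch, assuming only a stock JAX install; the array values are arbitrary and the exact device repr varies across JAX versions:

import jax
import jax.numpy as jnp

# Same CPU-only setting as the commit; must run before any computation.
jax.config.update('jax_platform_name', 'cpu')

print(jax.devices())  # expect a single CPU device, e.g. [CpuDevice(id=0)]

# JAX defaults to float32, matching the dtypes the commit switches to;
# float16 arithmetic is typically slow or poorly supported on CPU backends.
x = jnp.array([1.0, 2.0, 3.0])
print(x.dtype)  # float32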
 
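For readers following the train_step changes (the @jax.jit decorator is dropped in this commit), the function follows the standard Flax TrainState + optax.adam + jax.value_and_grad pattern. Below is a minimal, self-contained sketch of that pattern on a stand-in model; ToyModel, the shapes, and the MSE loss are illustrative assumptions and not part of app.py.

import jax
import jax.numpy as jnp
import optax
import flax.linen as nn
from flax.training import train_state

jax.config.update('jax_platform_name', 'cpu')  # mirror the commit's CPU-only setup

class ToyModel(nn.Module):  # hypothetical stand-in for the UNet
    @nn.compact
    def __call__(self, x):
        return nn.Dense(features=4)(x)

model = ToyModel()
rng = jax.random.PRNGKey(0)
variables = model.init(rng, jnp.ones((1, 8)))

state = train_state.TrainState.create(
    apply_fn=model.apply,
    params=variables["params"],
    tx=optax.adam(1e-5),  # same learning rate as the commit
)

def train_step(state, batch):
    # Differentiate the loss w.r.t. the parameters; value_and_grad returns both.
    def compute_loss(params):
        preds = state.apply_fn({"params": params}, batch["x"])
        return jnp.mean((preds - batch["y"]) ** 2)

    loss, grads = jax.value_and_grad(compute_loss)(state.params)
    return state.apply_gradients(grads=grads), loss

# jax.jit could be reapplied once input shapes are fixed, e.g.:
# train_step = jax.jit(train_step)

batch = {"x": jnp.ones((2, 8)), "y": jnp.zeros((2, 4))}
state, loss = train_step(state, batch)
print(float(loss))

Wrapping the step back in jax.jit is optional on CPU; it trades a one-time compile cost for faster steps when batch shapes stay constant.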