Update app.py
Browse files
app.py
CHANGED
@@ -1,6 +1,5 @@
|
|
1 |
import jax
|
2 |
import jax.numpy as jnp
|
3 |
-
from flax.jax_utils import replicate
|
4 |
from flax.training import train_state
|
5 |
import optax
|
6 |
from diffusers import FlaxStableDiffusionPipeline
|
@@ -12,6 +11,11 @@ from PIL import Image
|
|
12 |
import numpy as np
|
13 |
import gc
|
14 |
|
|
|
|
|
|
|
|
|
|
|
15 |
# Set up cache directories
|
16 |
cache_dir = "/tmp/huggingface_cache"
|
17 |
model_cache_dir = os.path.join(cache_dir, "stable_diffusion_model")
|
@@ -33,7 +37,7 @@ def get_model(model_id, revision):
|
|
33 |
pipeline, params = FlaxStableDiffusionPipeline.from_pretrained(
|
34 |
model_id,
|
35 |
revision=revision,
|
36 |
-
dtype=jnp.
|
37 |
)
|
38 |
with open(model_cache_file, 'wb') as f:
|
39 |
pickle.dump((pipeline, params), f)
|
@@ -53,8 +57,8 @@ def preprocess_images(examples):
|
|
53 |
image = Image.open(image)
|
54 |
if not isinstance(image, Image.Image):
|
55 |
raise ValueError(f"Unexpected image type: {type(image)}")
|
56 |
-
image = image.convert("RGB").resize((
|
57 |
-
image = np.array(image).astype(np.
|
58 |
return image.transpose(2, 0, 1)
|
59 |
|
60 |
return {"pixel_values": [process_image(img) for img in examples["image"]]}
|
@@ -73,7 +77,7 @@ try:
|
|
73 |
processed_dataset = pickle.load(f)
|
74 |
else:
|
75 |
print("Loading dataset from Hugging Face...")
|
76 |
-
dataset = load_dataset(dataset_name, split="train[:
|
77 |
print("Processing dataset...")
|
78 |
processed_dataset = dataset.map(preprocess_images, batched=True, remove_columns=dataset.column_names)
|
79 |
with open(dataset_cache_file, 'wb') as f:
|
@@ -90,8 +94,7 @@ def clear_jit_cache():
|
|
90 |
jax.clear_caches()
|
91 |
gc.collect()
|
92 |
|
93 |
-
# Training function
|
94 |
-
@jax.jit
|
95 |
def train_step(state, batch, rng):
|
96 |
def compute_loss(params, pixel_values, rng):
|
97 |
latents = pipeline.vae.apply(
|
@@ -99,9 +102,9 @@ def train_step(state, batch, rng):
|
|
99 |
pixel_values,
|
100 |
method=pipeline.vae.encode
|
101 |
).latent_dist.sample(rng)
|
102 |
-
latents = latents *
|
103 |
|
104 |
-
noise = jax.random.normal(rng, latents.shape
|
105 |
timesteps = jax.random.randint(
|
106 |
rng, (latents.shape[0],), 0, pipeline.scheduler.config.num_train_timesteps
|
107 |
)
|
@@ -114,8 +117,7 @@ def train_step(state, batch, rng):
|
|
114 |
|
115 |
encoder_hidden_states = jax.random.normal(
|
116 |
rng,
|
117 |
-
(latents.shape[0], pipeline.text_encoder.config.hidden_size)
|
118 |
-
dtype=jnp.float16
|
119 |
)
|
120 |
|
121 |
model_output = state.apply_fn.apply(
|
@@ -135,7 +137,7 @@ def train_step(state, batch, rng):
|
|
135 |
return state, loss
|
136 |
|
137 |
# Initialize training state
|
138 |
-
learning_rate =
|
139 |
optimizer = optax.adam(learning_rate)
|
140 |
state = train_state.TrainState.create(
|
141 |
apply_fn=unet,
|
@@ -144,15 +146,15 @@ state = train_state.TrainState.create(
|
|
144 |
)
|
145 |
|
146 |
# Training loop
|
147 |
-
num_epochs =
|
148 |
-
batch_size =
|
149 |
rng = jax.random.PRNGKey(0)
|
150 |
|
151 |
for epoch in range(num_epochs):
|
152 |
epoch_loss = 0
|
153 |
num_batches = 0
|
154 |
for batch in tqdm(processed_dataset.batch(batch_size)):
|
155 |
-
batch['pixel_values'] = jnp.array(batch['pixel_values']
|
156 |
rng, step_rng = jax.random.split(rng)
|
157 |
state, loss = train_step(state, batch, step_rng)
|
158 |
epoch_loss += loss
|
|
|
1 |
import jax
|
2 |
import jax.numpy as jnp
|
|
|
3 |
from flax.training import train_state
|
4 |
import optax
|
5 |
from diffusers import FlaxStableDiffusionPipeline
|
|
|
11 |
import numpy as np
|
12 |
import gc
|
13 |
|
14 |
+
# Force JAX to use CPU
|
15 |
+
jax.config.update('jax_platform_name', 'cpu')
|
16 |
+
|
17 |
+
print("Using CPU for computations")
|
18 |
+
|
19 |
# Set up cache directories
|
20 |
cache_dir = "/tmp/huggingface_cache"
|
21 |
model_cache_dir = os.path.join(cache_dir, "stable_diffusion_model")
|
|
|
37 |
pipeline, params = FlaxStableDiffusionPipeline.from_pretrained(
|
38 |
model_id,
|
39 |
revision=revision,
|
40 |
+
dtype=jnp.float32, # Use float32 for CPU
|
41 |
)
|
42 |
with open(model_cache_file, 'wb') as f:
|
43 |
pickle.dump((pipeline, params), f)
|
|
|
57 |
image = Image.open(image)
|
58 |
if not isinstance(image, Image.Image):
|
59 |
raise ValueError(f"Unexpected image type: {type(image)}")
|
60 |
+
image = image.convert("RGB").resize((128, 128)) # Further reduced image size
|
61 |
+
image = np.array(image).astype(np.float32) / 255.0
|
62 |
return image.transpose(2, 0, 1)
|
63 |
|
64 |
return {"pixel_values": [process_image(img) for img in examples["image"]]}
|
|
|
77 |
processed_dataset = pickle.load(f)
|
78 |
else:
|
79 |
print("Loading dataset from Hugging Face...")
|
80 |
+
dataset = load_dataset(dataset_name, split="train[:500]") # Load only first 500 samples
|
81 |
print("Processing dataset...")
|
82 |
processed_dataset = dataset.map(preprocess_images, batched=True, remove_columns=dataset.column_names)
|
83 |
with open(dataset_cache_file, 'wb') as f:
|
|
|
94 |
jax.clear_caches()
|
95 |
gc.collect()
|
96 |
|
97 |
+
# Training function
|
|
|
98 |
def train_step(state, batch, rng):
|
99 |
def compute_loss(params, pixel_values, rng):
|
100 |
latents = pipeline.vae.apply(
|
|
|
102 |
pixel_values,
|
103 |
method=pipeline.vae.encode
|
104 |
).latent_dist.sample(rng)
|
105 |
+
latents = latents * 0.18215
|
106 |
|
107 |
+
noise = jax.random.normal(rng, latents.shape)
|
108 |
timesteps = jax.random.randint(
|
109 |
rng, (latents.shape[0],), 0, pipeline.scheduler.config.num_train_timesteps
|
110 |
)
|
|
|
117 |
|
118 |
encoder_hidden_states = jax.random.normal(
|
119 |
rng,
|
120 |
+
(latents.shape[0], pipeline.text_encoder.config.hidden_size)
|
|
|
121 |
)
|
122 |
|
123 |
model_output = state.apply_fn.apply(
|
|
|
137 |
return state, loss
|
138 |
|
139 |
# Initialize training state
|
140 |
+
learning_rate = 1e-5
|
141 |
optimizer = optax.adam(learning_rate)
|
142 |
state = train_state.TrainState.create(
|
143 |
apply_fn=unet,
|
|
|
146 |
)
|
147 |
|
148 |
# Training loop
|
149 |
+
num_epochs = 3 # Further reduced number of epochs
|
150 |
+
batch_size = 2 # Reduced batch size for CPU
|
151 |
rng = jax.random.PRNGKey(0)
|
152 |
|
153 |
for epoch in range(num_epochs):
|
154 |
epoch_loss = 0
|
155 |
num_batches = 0
|
156 |
for batch in tqdm(processed_dataset.batch(batch_size)):
|
157 |
+
batch['pixel_values'] = jnp.array(batch['pixel_values'])
|
158 |
rng, step_rng = jax.random.split(rng)
|
159 |
state, loss = train_step(state, batch, step_rng)
|
160 |
epoch_loss += loss
|