Update app.py
Browse files
app.py
CHANGED
@@ -57,7 +57,7 @@ def preprocess_images(examples):
|
|
57 |
image = Image.open(image)
|
58 |
if not isinstance(image, Image.Image):
|
59 |
raise ValueError(f"Unexpected image type: {type(image)}")
|
60 |
-
image = image.convert("RGB").resize((
|
61 |
image = np.array(image).astype(np.float32) / 255.0
|
62 |
return image.transpose(2, 0, 1)
|
63 |
|
@@ -97,6 +97,7 @@ def clear_jit_cache():
|
|
97 |
# Training function
|
98 |
def train_step(state, batch, rng):
|
99 |
def compute_loss(params, pixel_values, rng):
|
|
|
100 |
latents = pipeline.vae.apply(
|
101 |
{"params": params["vae"]},
|
102 |
pixel_values,
|
@@ -104,10 +105,15 @@ def train_step(state, batch, rng):
|
|
104 |
).latent_dist.sample(rng)
|
105 |
latents = latents * 0.18215
|
106 |
|
|
|
107 |
noise = jax.random.normal(rng, latents.shape)
|
|
|
|
|
108 |
timesteps = jax.random.randint(
|
109 |
rng, (latents.shape[0],), 0, pipeline.scheduler.config.num_train_timesteps
|
110 |
)
|
|
|
|
|
111 |
noisy_latents = pipeline.scheduler.add_noise(
|
112 |
pipeline.scheduler.create_state(),
|
113 |
original_samples=latents,
|
@@ -115,11 +121,13 @@ def train_step(state, batch, rng):
|
|
115 |
timesteps=timesteps
|
116 |
)
|
117 |
|
|
|
118 |
encoder_hidden_states = jax.random.normal(
|
119 |
rng,
|
120 |
(latents.shape[0], pipeline.text_encoder.config.hidden_size)
|
121 |
)
|
122 |
|
|
|
123 |
model_output = state.apply_fn.apply(
|
124 |
{'params': params["unet"]},
|
125 |
noisy_latents,
|
@@ -128,6 +136,7 @@ def train_step(state, batch, rng):
|
|
128 |
train=True,
|
129 |
)
|
130 |
|
|
|
131 |
return jnp.mean((model_output - noise) ** 2)
|
132 |
|
133 |
grad_fn = jax.value_and_grad(compute_loss)
|
@@ -136,18 +145,9 @@ def train_step(state, batch, rng):
|
|
136 |
state = state.apply_gradients(grads=grads)
|
137 |
return state, loss
|
138 |
|
139 |
-
# Initialize training state
|
140 |
-
learning_rate = 1e-5
|
141 |
-
optimizer = optax.adam(learning_rate)
|
142 |
-
state = train_state.TrainState.create(
|
143 |
-
apply_fn=unet,
|
144 |
-
params={"unet": params["unet"], "vae": params["vae"]},
|
145 |
-
tx=optimizer,
|
146 |
-
)
|
147 |
-
|
148 |
# Training loop
|
149 |
-
num_epochs = 3
|
150 |
-
batch_size =
|
151 |
rng = jax.random.PRNGKey(0)
|
152 |
|
153 |
for epoch in range(num_epochs):
|
|
|
57 |
image = Image.open(image)
|
58 |
if not isinstance(image, Image.Image):
|
59 |
raise ValueError(f"Unexpected image type: {type(image)}")
|
60 |
+
image = image.convert("RGB").resize((512, 512)) # Resize to a fixed 512x512 (does not keep the original size)
|
61 |
image = np.array(image).astype(np.float32) / 255.0
|
62 |
return image.transpose(2, 0, 1)
|
63 |
|
|
|
97 |
# Training function
|
98 |
def train_step(state, batch, rng):
|
99 |
def compute_loss(params, pixel_values, rng):
|
100 |
+
# Encode images to latent space
|
101 |
latents = pipeline.vae.apply(
|
102 |
{"params": params["vae"]},
|
103 |
pixel_values,
|
|
|
105 |
).latent_dist.sample(rng)
|
106 |
latents = latents * 0.18215
|
107 |
|
108 |
+
# Generate random noise
|
109 |
noise = jax.random.normal(rng, latents.shape)
|
110 |
+
|
111 |
+
# Sample random timesteps
|
112 |
timesteps = jax.random.randint(
|
113 |
rng, (latents.shape[0],), 0, pipeline.scheduler.config.num_train_timesteps
|
114 |
)
|
115 |
+
|
116 |
+
# Add noise to latents
|
117 |
noisy_latents = pipeline.scheduler.add_noise(
|
118 |
pipeline.scheduler.create_state(),
|
119 |
original_samples=latents,
|
|
|
121 |
timesteps=timesteps
|
122 |
)
|
123 |
|
124 |
+
# Generate random encoder hidden states (simulating text embeddings)
|
125 |
encoder_hidden_states = jax.random.normal(
|
126 |
rng,
|
127 |
(latents.shape[0], pipeline.text_encoder.config.hidden_size)
|
128 |
)
|
129 |
|
130 |
+
# Predict noise
|
131 |
model_output = state.apply_fn.apply(
|
132 |
{'params': params["unet"]},
|
133 |
noisy_latents,
|
|
|
136 |
train=True,
|
137 |
)
|
138 |
|
139 |
+
# Compute loss
|
140 |
return jnp.mean((model_output - noise) ** 2)
|
141 |
|
142 |
grad_fn = jax.value_and_grad(compute_loss)
|
|
|
145 |
state = state.apply_gradients(grads=grads)
|
146 |
return state, loss
|
147 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
148 |
# Training loop
|
149 |
+
num_epochs = 3
|
150 |
+
batch_size = 1 # Reduced batch size due to memory constraints
|
151 |
rng = jax.random.PRNGKey(0)
|
152 |
|
153 |
for epoch in range(num_epochs):
|