Update app.py
Browse files
app.py
CHANGED
@@ -3,7 +3,7 @@ import jax.numpy as jnp
|
|
3 |
from flax.jax_utils import replicate
|
4 |
from flax.training import train_state
|
5 |
import optax
|
6 |
-
from diffusers import FlaxStableDiffusionPipeline
|
7 |
from diffusers.schedulers import FlaxPNDMScheduler
|
8 |
from datasets import load_dataset
|
9 |
from tqdm.auto import tqdm
|
@@ -53,8 +53,22 @@ pipeline, params = get_model(model_id, "flax")
|
|
53 |
custom_scheduler = CustomFlaxPNDMScheduler.from_config(pipeline.scheduler.config)
|
54 |
pipeline.scheduler = custom_scheduler
|
55 |
|
56 |
-
#
|
57 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
58 |
|
59 |
# Load and preprocess your dataset
|
60 |
def preprocess_images(examples):
|
@@ -129,11 +143,6 @@ def train_step(state, batch, rng):
|
|
129 |
dtype=jnp.float32
|
130 |
)
|
131 |
|
132 |
-
# Ensure noisy_latents has the correct number of channels
|
133 |
-
if noisy_latents.shape[-1] != pipeline.unet.config.in_channels:
|
134 |
-
pad_width = [(0, 0)] * (noisy_latents.ndim - 1) + [(0, pipeline.unet.config.in_channels - noisy_latents.shape[-1])]
|
135 |
-
noisy_latents = jnp.pad(noisy_latents, pad_width, mode='constant')
|
136 |
-
|
137 |
# Use the correct method to call the UNet
|
138 |
model_output = unet.apply(
|
139 |
{'params': params["unet"]},
|
|
|
3 |
from flax.jax_utils import replicate
|
4 |
from flax.training import train_state
|
5 |
import optax
|
6 |
+
from diffusers import FlaxStableDiffusionPipeline, FlaxUNet2DConditionModel
|
7 |
from diffusers.schedulers import FlaxPNDMScheduler
|
8 |
from datasets import load_dataset
|
9 |
from tqdm.auto import tqdm
|
|
|
53 |
custom_scheduler = CustomFlaxPNDMScheduler.from_config(pipeline.scheduler.config)
|
54 |
pipeline.scheduler = custom_scheduler
|
55 |
|
56 |
+
# Modify UNet configuration
|
57 |
+
unet_config = pipeline.unet.config
|
58 |
+
unet_config.in_channels = 4 # Set to match the latent space dimensions
|
59 |
+
|
60 |
+
# Create a new UNet with the modified configuration
|
61 |
+
unet = FlaxUNet2DConditionModel(unet_config)
|
62 |
+
|
63 |
+
# Initialize the new UNet with random weights
|
64 |
+
rng = jax.random.PRNGKey(0)
|
65 |
+
sample_input = jnp.ones((1, 64, 64, 4))
|
66 |
+
sample_t = jnp.ones((1,))
|
67 |
+
sample_encoder_hidden_states = jnp.ones((1, 77, 768))
|
68 |
+
new_unet_params = unet.init(rng, sample_input, sample_t, sample_encoder_hidden_states)["params"]
|
69 |
+
|
70 |
+
# Replace the UNet params in the pipeline
|
71 |
+
params["unet"] = new_unet_params
|
72 |
|
73 |
# Load and preprocess your dataset
|
74 |
def preprocess_images(examples):
|
|
|
143 |
dtype=jnp.float32
|
144 |
)
|
145 |
|
|
|
|
|
|
|
|
|
|
|
146 |
# Use the correct method to call the UNet
|
147 |
model_output = unet.apply(
|
148 |
{'params': params["unet"]},
|