Update app.py
Browse files
app.py
CHANGED
@@ -3,7 +3,7 @@ import jax.numpy as jnp
|
|
3 |
from flax.jax_utils import replicate
|
4 |
from flax.training import train_state
|
5 |
import optax
|
6 |
-
from diffusers import FlaxStableDiffusionPipeline
|
7 |
from diffusers.schedulers import FlaxPNDMScheduler
|
8 |
from datasets import load_dataset
|
9 |
from tqdm.auto import tqdm
|
@@ -53,24 +53,22 @@ pipeline, params = get_model(model_id, "flax")
|
|
53 |
custom_scheduler = CustomFlaxPNDMScheduler.from_config(pipeline.scheduler.config)
|
54 |
pipeline.scheduler = custom_scheduler
|
55 |
|
56 |
-
#
|
57 |
-
|
|
|
58 |
|
59 |
-
#
|
60 |
-
|
61 |
-
|
62 |
-
|
63 |
-
|
64 |
-
|
65 |
-
|
66 |
-
|
67 |
-
|
68 |
-
# Update the parameters
|
69 |
-
params['unet']['conv_in']['kernel'] = new_weight
|
70 |
-
|
71 |
-
return params
|
72 |
|
73 |
-
params
|
|
|
74 |
|
75 |
# Load and preprocess your dataset
|
76 |
def preprocess_images(examples):
|
|
|
3 |
from flax.jax_utils import replicate
|
4 |
from flax.training import train_state
|
5 |
import optax
|
6 |
+
from diffusers import FlaxStableDiffusionPipeline, FlaxUNet2DConditionModel
|
7 |
from diffusers.schedulers import FlaxPNDMScheduler
|
8 |
from datasets import load_dataset
|
9 |
from tqdm.auto import tqdm
|
|
|
53 |
custom_scheduler = CustomFlaxPNDMScheduler.from_config(pipeline.scheduler.config)
|
54 |
pipeline.scheduler = custom_scheduler
|
55 |
|
56 |
+
# Modify UNet configuration
|
57 |
+
unet_config = pipeline.unet.config
|
58 |
+
unet_config.in_channels = 4 # Set to match the latent space dimensions
|
59 |
|
60 |
+
# Create a new UNet with the modified configuration
|
61 |
+
unet = FlaxUNet2DConditionModel(unet_config)
|
62 |
+
|
63 |
+
# Initialize the new UNet with random weights
|
64 |
+
rng = jax.random.PRNGKey(0)
|
65 |
+
sample_input = jnp.ones((1, 64, 64, 4))
|
66 |
+
sample_t = jnp.ones((1,))
|
67 |
+
sample_encoder_hidden_states = jnp.ones((1, 77, 768))
|
68 |
+
new_unet_params = unet.init(rng, sample_input, sample_t, sample_encoder_hidden_states)["params"]
|
|
|
|
|
|
|
|
|
69 |
|
70 |
+
# Replace the UNet params in the pipeline
|
71 |
+
params["unet"] = new_unet_params
|
72 |
|
73 |
# Load and preprocess your dataset
|
74 |
def preprocess_images(examples):
|