uruguayai commited on
Commit
faf4066
·
verified ·
1 Parent(s): 76dfe67

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +17 -8
app.py CHANGED
@@ -3,7 +3,7 @@ import jax.numpy as jnp
3
  from flax.jax_utils import replicate
4
  from flax.training import train_state
5
  import optax
6
- from diffusers import FlaxStableDiffusionPipeline
7
  from diffusers.schedulers import FlaxPNDMScheduler
8
  from datasets import load_dataset
9
  from tqdm.auto import tqdm
@@ -53,8 +53,22 @@ pipeline, params = get_model(model_id, "flax")
53
  custom_scheduler = CustomFlaxPNDMScheduler.from_config(pipeline.scheduler.config)
54
  pipeline.scheduler = custom_scheduler
55
 
56
- # Extract UNet from pipeline
57
- unet = pipeline.unet
 
 
 
 
 
 
 
 
 
 
 
 
 
 
58
 
59
  # Load and preprocess your dataset
60
  def preprocess_images(examples):
@@ -129,11 +143,6 @@ def train_step(state, batch, rng):
129
  dtype=jnp.float32
130
  )
131
 
132
- # Ensure noisy_latents has the correct number of channels
133
- if noisy_latents.shape[-1] != pipeline.unet.config.in_channels:
134
- pad_width = [(0, 0)] * (noisy_latents.ndim - 1) + [(0, pipeline.unet.config.in_channels - noisy_latents.shape[-1])]
135
- noisy_latents = jnp.pad(noisy_latents, pad_width, mode='constant')
136
-
137
  # Use the correct method to call the UNet
138
  model_output = unet.apply(
139
  {'params': params["unet"]},
 
3
  from flax.jax_utils import replicate
4
  from flax.training import train_state
5
  import optax
6
+ from diffusers import FlaxStableDiffusionPipeline, FlaxUNet2DConditionModel
7
  from diffusers.schedulers import FlaxPNDMScheduler
8
  from datasets import load_dataset
9
  from tqdm.auto import tqdm
 
53
  custom_scheduler = CustomFlaxPNDMScheduler.from_config(pipeline.scheduler.config)
54
  pipeline.scheduler = custom_scheduler
55
 
56
+ # Modify UNet configuration
57
+ unet_config = pipeline.unet.config
58
+ unet_config.in_channels = 4 # Set to match the latent space dimensions
59
+
60
+ # Create a new UNet with the modified configuration
61
+ unet = FlaxUNet2DConditionModel(unet_config)
62
+
63
+ # Initialize the new UNet with random weights
64
+ rng = jax.random.PRNGKey(0)
65
+ sample_input = jnp.ones((1, 64, 64, 4))
66
+ sample_t = jnp.ones((1,))
67
+ sample_encoder_hidden_states = jnp.ones((1, 77, 768))
68
+ new_unet_params = unet.init(rng, sample_input, sample_t, sample_encoder_hidden_states)["params"]
69
+
70
+ # Replace the UNet params in the pipeline
71
+ params["unet"] = new_unet_params
72
 
73
  # Load and preprocess your dataset
74
  def preprocess_images(examples):
 
143
  dtype=jnp.float32
144
  )
145
 
 
 
 
 
 
146
  # Use the correct method to call the UNet
147
  model_output = unet.apply(
148
  {'params': params["unet"]},