uruguayai commited on
Commit
0b99dda
·
verified ·
1 Parent(s): 16dd569

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +15 -17
app.py CHANGED
@@ -3,7 +3,7 @@ import jax.numpy as jnp
3
  from flax.jax_utils import replicate
4
  from flax.training import train_state
5
  import optax
6
- from diffusers import FlaxStableDiffusionPipeline
7
  from diffusers.schedulers import FlaxPNDMScheduler
8
  from datasets import load_dataset
9
  from tqdm.auto import tqdm
@@ -53,24 +53,22 @@ pipeline, params = get_model(model_id, "flax")
53
  custom_scheduler = CustomFlaxPNDMScheduler.from_config(pipeline.scheduler.config)
54
  pipeline.scheduler = custom_scheduler
55
 
56
- # Extract UNet from pipeline
57
- unet = pipeline.unet
 
58
 
59
- # Modify the UNet's input layer
60
- def modify_unet_input_layer(params):
61
- conv_in_weight = params['unet']['conv_in']['kernel']
62
- conv_in_bias = params['unet']['conv_in']['bias']
63
-
64
- # Adjust the weight tensor
65
- new_weight = jnp.zeros((3, 3, 4, 320), dtype=jnp.float32)
66
- new_weight = new_weight.at[:, :, :4, :].set(conv_in_weight[:, :, :4, :])
67
-
68
- # Update the parameters
69
- params['unet']['conv_in']['kernel'] = new_weight
70
-
71
- return params
72
 
73
- params = modify_unet_input_layer(params)
 
74
 
75
  # Load and preprocess your dataset
76
  def preprocess_images(examples):
 
3
  from flax.jax_utils import replicate
4
  from flax.training import train_state
5
  import optax
6
+ from diffusers import FlaxStableDiffusionPipeline, FlaxUNet2DConditionModel
7
  from diffusers.schedulers import FlaxPNDMScheduler
8
  from datasets import load_dataset
9
  from tqdm.auto import tqdm
 
53
  custom_scheduler = CustomFlaxPNDMScheduler.from_config(pipeline.scheduler.config)
54
  pipeline.scheduler = custom_scheduler
55
 
56
+ # Modify UNet configuration
57
+ unet_config = pipeline.unet.config
58
+ unet_config.in_channels = 4 # Set to match the latent space dimensions
59
 
60
+ # Create a new UNet with the modified configuration
61
+ unet = FlaxUNet2DConditionModel(unet_config)
62
+
63
+ # Initialize the new UNet with random weights
64
+ rng = jax.random.PRNGKey(0)
65
+ sample_input = jnp.ones((1, 64, 64, 4))
66
+ sample_t = jnp.ones((1,))
67
+ sample_encoder_hidden_states = jnp.ones((1, 77, 768))
68
+ new_unet_params = unet.init(rng, sample_input, sample_t, sample_encoder_hidden_states)["params"]
 
 
 
 
69
 
70
+ # Replace the UNet params in the pipeline
71
+ params["unet"] = new_unet_params
72
 
73
  # Load and preprocess your dataset
74
  def preprocess_images(examples):