uruguayai committed
Commit 967b314 · verified · 1 Parent(s): a96d1af

Update app.py

Files changed (1)
  1. app.py +11 -42
app.py CHANGED
@@ -53,40 +53,18 @@ pipeline, params = get_model(model_id, "flax")
 custom_scheduler = CustomFlaxPNDMScheduler.from_config(pipeline.scheduler.config)
 pipeline.scheduler = custom_scheduler
 
-# Modify UNet configuration
-unet_config = pipeline.unet.config
-unet_config.in_channels = 4 # Set to match the latent space dimensions
-
-# Modify the UNet architecture
-def modify_unet_config(config):
-    config.down_block_types = [
-        "CrossAttnDownBlock2D",
-        "CrossAttnDownBlock2D",
-        "CrossAttnDownBlock2D",
-        "DownBlock2D"
-    ]
-    config.up_block_types = [
-        "UpBlock2D",
-        "CrossAttnUpBlock2D",
-        "CrossAttnUpBlock2D",
-        "CrossAttnUpBlock2D"
-    ]
-    return config
-
-modified_unet_config = modify_unet_config(unet_config)
-
-# Create a new UNet with the modified configuration
-unet = FlaxUNet2DConditionModel(modified_unet_config)
-
-# Initialize the new UNet with random weights
-rng = jax.random.PRNGKey(0)
-sample_input = jnp.ones((1, 64, 64, 4))
-sample_t = jnp.ones((1,))
-sample_encoder_hidden_states = jnp.ones((1, 77, 768))
-new_unet_params = unet.init(rng, sample_input, sample_t, sample_encoder_hidden_states)["params"]
+# Extract UNet from pipeline
+unet = pipeline.unet
+
+# Adjust the input layer of the UNet
+def adjust_unet_input_layer(params):
+    conv_in_weight = params['unet']['conv_in']['kernel']
+    new_conv_in_weight = jnp.zeros((3, 3, 4, 320), dtype=jnp.float32)
+    new_conv_in_weight = new_conv_in_weight.at[:, :, :4, :].set(conv_in_weight[:, :, :4, :])
+    params['unet']['conv_in']['kernel'] = new_conv_in_weight
+    return params
 
-# Replace the UNet params in the pipeline
-params["unet"] = new_unet_params
+params = adjust_unet_input_layer(params)
 
 # Load and preprocess your dataset
 def preprocess_images(examples):
@@ -124,10 +102,6 @@ print(f"Processed dataset size: {len(processed_dataset)}")
 # Training function
 def train_step(state, batch, rng):
     def compute_loss(params, pixel_values, rng):
-        print("pixel_values dtype:", pixel_values.dtype)
-        print("params dtypes:", jax.tree_map(lambda x: x.dtype, params))
-        print("rng dtype:", rng.dtype)
-
         pixel_values = jnp.array(pixel_values, dtype=jnp.float32)
 
         latents = pipeline.vae.apply(
@@ -143,11 +117,6 @@ def train_step(state, batch, rng):
             rng, (latents.shape[0],), 0, pipeline.scheduler.config.num_train_timesteps
         )
 
-        print("timesteps dtype:", timesteps.dtype)
-        print("latents dtype:", latents.dtype)
-        print("noise dtype:", noise.dtype)
-        print("latents shape:", latents.shape)
-
        noisy_latents = pipeline.scheduler.add_noise(
            pipeline.scheduler.create_state(),
            original_samples=latents,
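Note on the first hunk: the replacement code keeps the pipeline's pretrained UNet and only rewrites its conv_in kernel inside the params tree, instead of re-initializing a fresh FlaxUNet2DConditionModel from a modified config. Below is a minimal, self-contained sketch of that kernel adjustment in isolation. It assumes Flax convolution kernels are laid out as (height, width, in_channels, out_channels) and that (3, 3, 4, 320) is the right target shape for this checkpoint; adjust_conv_in_kernel and the dummy kernel are illustrative names that do not appear in app.py.

import jax.numpy as jnp

def adjust_conv_in_kernel(kernel, target_shape=(3, 3, 4, 320)):
    # Zero-initialize a kernel of the target shape, then copy the overlapping
    # input channels of the pretrained kernel into it (here both sides have
    # 4 input channels, so this is effectively a shape/dtype-preserving copy).
    new_kernel = jnp.zeros(target_shape, dtype=kernel.dtype)
    in_ch = min(kernel.shape[2], target_shape[2])
    return new_kernel.at[:, :, :in_ch, :].set(kernel[:, :, :in_ch, :])

# Shape check with a dummy kernel standing in for params['unet']['conv_in']['kernel'].
dummy_kernel = jnp.ones((3, 3, 4, 320), dtype=jnp.float32)
assert adjust_conv_in_kernel(dummy_kernel).shape == (3, 3, 4, 320)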
 
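Note on the second and third hunks: the deleted print(...) calls sat inside compute_loss, which runs as part of train_step. If that step is wrapped in jax.jit or pmap elsewhere in app.py (an assumption, since the wrapper is outside this diff), plain Python prints fire only once while the function is traced and show tracers rather than runtime values, so dropping them loses little. When runtime values are genuinely needed inside a traced function, jax.debug.print is the standard JAX tool; the snippet below is a generic sketch of that pattern, not code from this commit.

import jax
import jax.numpy as jnp

@jax.jit
def noisy_mean(latents, noise):
    # Executes at run time (unlike plain print under jit), so it sees concrete values.
    jax.debug.print("latents mean = {m}, noise mean = {n}",
                    m=jnp.mean(latents), n=jnp.mean(noise))
    return latents + noise

noisy_mean(jnp.ones((1, 64, 64, 4)), jnp.zeros((1, 64, 64, 4)))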
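Note on the unchanged context at the end of the last hunk: it calls pipeline.scheduler.add_noise(pipeline.scheduler.create_state(), ...). In the Flax diffusers API the schedulers are stateless and add_noise takes an explicit scheduler state, which is normally created once and reused across steps rather than rebuilt on every call; that is how the stock FlaxPNDMScheduler behaves, and presumably the CustomFlaxPNDMScheduler subclass here as well. A small standalone sketch of the pattern, using a default-configured FlaxPNDMScheduler and random tensors in place of the app's real latents (requires diffusers with Flax support installed):

import jax
import jax.numpy as jnp
from diffusers import FlaxPNDMScheduler

scheduler = FlaxPNDMScheduler()             # default config; app.py derives its config from the pipeline
scheduler_state = scheduler.create_state()  # create once, reuse in every training step

latents = jax.random.normal(jax.random.PRNGKey(0), (1, 64, 64, 4))
noise = jax.random.normal(jax.random.PRNGKey(1), latents.shape)
timesteps = jax.random.randint(
    jax.random.PRNGKey(2), (latents.shape[0],), 0, scheduler.config.num_train_timesteps
)

noisy_latents = scheduler.add_noise(
    scheduler_state,
    original_samples=latents,
    noise=noise,
    timesteps=timesteps,
)
print(noisy_latents.shape)  # (1, 64, 64, 4)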