uruguayai commited on
Commit
41e0af4
·
verified ·
1 Parent(s): cfafe9a

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +5 -7
app.py CHANGED
@@ -74,15 +74,13 @@ print(unet.config)
74
 
75
  # Adjust the input layer of the UNet
76
  def adjust_unet_input_layer(params):
77
- conv_in_weight = params['unet']['conv_in']['kernel']
78
  print(f"Original conv_in weight shape: {conv_in_weight.shape}")
79
- if conv_in_weight.shape[2] == 64:
80
- new_conv_in_weight = conv_in_weight[:, :, :4, :]
81
- else:
82
  new_conv_in_weight = jnp.zeros((3, 3, 4, 320), dtype=jnp.float32)
83
  new_conv_in_weight = new_conv_in_weight.at[:, :, :3, :].set(conv_in_weight[:, :, :3, :])
84
- params['unet']['conv_in']['kernel'] = new_conv_in_weight
85
- print(f"New conv_in weight shape: {params['unet']['conv_in']['kernel'].shape}")
86
  return params
87
 
88
  params = adjust_unet_input_layer(params)
@@ -205,7 +203,7 @@ print("Filtered UNet config keys:", filtered_unet_config.keys())
205
 
206
  adjusted_unet = FlaxUNet2DConditionModel(**filtered_unet_config)
207
  adjusted_params = adjusted_unet.init(jax.random.PRNGKey(0), jnp.ones((1, 4, 64, 64)), jnp.ones((1,)), jnp.ones((1, 77, 768)))
208
- adjusted_params = float32_params['unet'] # Use only UNet params
209
 
210
  state = train_state.TrainState.create(
211
  apply_fn=adjusted_unet.apply,
 
74
 
75
  # Adjust the input layer of the UNet
76
  def adjust_unet_input_layer(params):
77
+ conv_in_weight = params['conv_in']['kernel']
78
  print(f"Original conv_in weight shape: {conv_in_weight.shape}")
79
+ if conv_in_weight.shape[2] != 4:
 
 
80
  new_conv_in_weight = jnp.zeros((3, 3, 4, 320), dtype=jnp.float32)
81
  new_conv_in_weight = new_conv_in_weight.at[:, :, :3, :].set(conv_in_weight[:, :, :3, :])
82
+ params['conv_in']['kernel'] = new_conv_in_weight
83
+ print(f"New conv_in weight shape: {params['conv_in']['kernel'].shape}")
84
  return params
85
 
86
  params = adjust_unet_input_layer(params)
 
203
 
204
  adjusted_unet = FlaxUNet2DConditionModel(**filtered_unet_config)
205
  adjusted_params = adjusted_unet.init(jax.random.PRNGKey(0), jnp.ones((1, 4, 64, 64)), jnp.ones((1,)), jnp.ones((1, 77, 768)))
206
+ adjusted_params = adjust_unet_input_layer(adjusted_params) # Adjust the input layer
207
 
208
  state = train_state.TrainState.create(
209
  apply_fn=adjusted_unet.apply,