Update app.py
Browse files
app.py
CHANGED
@@ -74,15 +74,13 @@ print(unet.config)
|
|
74 |
|
75 |
# Adjust the input layer of the UNet
|
76 |
def adjust_unet_input_layer(params):
|
77 |
-
conv_in_weight = params['
|
78 |
print(f"Original conv_in weight shape: {conv_in_weight.shape}")
|
79 |
-
if conv_in_weight.shape[2]
|
80 |
-
new_conv_in_weight = conv_in_weight[:, :, :4, :]
|
81 |
-
else:
|
82 |
new_conv_in_weight = jnp.zeros((3, 3, 4, 320), dtype=jnp.float32)
|
83 |
new_conv_in_weight = new_conv_in_weight.at[:, :, :3, :].set(conv_in_weight[:, :, :3, :])
|
84 |
-
|
85 |
-
print(f"New conv_in weight shape: {params['
|
86 |
return params
|
87 |
|
88 |
params = adjust_unet_input_layer(params)
|
@@ -205,7 +203,7 @@ print("Filtered UNet config keys:", filtered_unet_config.keys())
|
|
205 |
|
206 |
adjusted_unet = FlaxUNet2DConditionModel(**filtered_unet_config)
|
207 |
adjusted_params = adjusted_unet.init(jax.random.PRNGKey(0), jnp.ones((1, 4, 64, 64)), jnp.ones((1,)), jnp.ones((1, 77, 768)))
|
208 |
-
adjusted_params =
|
209 |
|
210 |
state = train_state.TrainState.create(
|
211 |
apply_fn=adjusted_unet.apply,
|
|
|
74 |
|
75 |
# Adjust the input layer of the UNet
|
76 |
def adjust_unet_input_layer(params):
|
77 |
+
conv_in_weight = params['conv_in']['kernel']
|
78 |
print(f"Original conv_in weight shape: {conv_in_weight.shape}")
|
79 |
+
if conv_in_weight.shape[2] != 4:
|
|
|
|
|
80 |
new_conv_in_weight = jnp.zeros((3, 3, 4, 320), dtype=jnp.float32)
|
81 |
new_conv_in_weight = new_conv_in_weight.at[:, :, :3, :].set(conv_in_weight[:, :, :3, :])
|
82 |
+
params['conv_in']['kernel'] = new_conv_in_weight
|
83 |
+
print(f"New conv_in weight shape: {params['conv_in']['kernel'].shape}")
|
84 |
return params
|
85 |
|
86 |
params = adjust_unet_input_layer(params)
|
|
|
203 |
|
204 |
adjusted_unet = FlaxUNet2DConditionModel(**filtered_unet_config)
|
205 |
adjusted_params = adjusted_unet.init(jax.random.PRNGKey(0), jnp.ones((1, 4, 64, 64)), jnp.ones((1,)), jnp.ones((1, 77, 768)))
|
206 |
+
adjusted_params = adjust_unet_input_layer(adjusted_params) # Adjust the input layer
|
207 |
|
208 |
state = train_state.TrainState.create(
|
209 |
apply_fn=adjusted_unet.apply,
|