Update app.py
Browse files
app.py
CHANGED
@@ -57,8 +57,26 @@ pipeline.scheduler = custom_scheduler
|
|
57 |
unet_config = pipeline.unet.config
|
58 |
unet_config.in_channels = 4 # Set to match the latent space dimensions
|
59 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
60 |
# Create a new UNet with the modified configuration
|
61 |
-
unet = FlaxUNet2DConditionModel(
|
62 |
|
63 |
# Initialize the new UNet with random weights
|
64 |
rng = jax.random.PRNGKey(0)
|
|
|
57 |
unet_config = pipeline.unet.config
|
58 |
unet_config.in_channels = 4 # Set to match the latent space dimensions
|
59 |
|
60 |
+
# Modify the UNet architecture
|
61 |
+
def modify_unet_config(config):
|
62 |
+
config.down_block_types = [
|
63 |
+
"CrossAttnDownBlock2D",
|
64 |
+
"CrossAttnDownBlock2D",
|
65 |
+
"CrossAttnDownBlock2D",
|
66 |
+
"DownBlock2D"
|
67 |
+
]
|
68 |
+
config.up_block_types = [
|
69 |
+
"UpBlock2D",
|
70 |
+
"CrossAttnUpBlock2D",
|
71 |
+
"CrossAttnUpBlock2D",
|
72 |
+
"CrossAttnUpBlock2D"
|
73 |
+
]
|
74 |
+
return config
|
75 |
+
|
76 |
+
modified_unet_config = modify_unet_config(unet_config)
|
77 |
+
|
78 |
# Create a new UNet with the modified configuration
|
79 |
+
unet = FlaxUNet2DConditionModel(modified_unet_config)
|
80 |
|
81 |
# Initialize the new UNet with random weights
|
82 |
rng = jax.random.PRNGKey(0)
|