uruguayai commited on
Commit
a96d1af
·
verified ·
1 Parent(s): 0b99dda

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +19 -1
app.py CHANGED
@@ -57,8 +57,26 @@ pipeline.scheduler = custom_scheduler
57
  unet_config = pipeline.unet.config
58
  unet_config.in_channels = 4 # Set to match the latent space dimensions
59
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
60
  # Create a new UNet with the modified configuration
61
- unet = FlaxUNet2DConditionModel(unet_config)
62
 
63
  # Initialize the new UNet with random weights
64
  rng = jax.random.PRNGKey(0)
 
57
  unet_config = pipeline.unet.config
58
  unet_config.in_channels = 4 # Set to match the latent space dimensions
59
 
60
+ # Modify the UNet architecture
61
+ def modify_unet_config(config):
62
+ config.down_block_types = [
63
+ "CrossAttnDownBlock2D",
64
+ "CrossAttnDownBlock2D",
65
+ "CrossAttnDownBlock2D",
66
+ "DownBlock2D"
67
+ ]
68
+ config.up_block_types = [
69
+ "UpBlock2D",
70
+ "CrossAttnUpBlock2D",
71
+ "CrossAttnUpBlock2D",
72
+ "CrossAttnUpBlock2D"
73
+ ]
74
+ return config
75
+
76
+ modified_unet_config = modify_unet_config(unet_config)
77
+
78
  # Create a new UNet with the modified configuration
79
+ unet = FlaxUNet2DConditionModel(modified_unet_config)
80
 
81
  # Initialize the new UNet with random weights
82
  rng = jax.random.PRNGKey(0)