uruguayai commited on
Commit
ae48a93
·
verified ·
1 Parent(s): faf4066

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +3 -17
app.py CHANGED
@@ -3,7 +3,7 @@ import jax.numpy as jnp
3
  from flax.jax_utils import replicate
4
  from flax.training import train_state
5
  import optax
6
- from diffusers import FlaxStableDiffusionPipeline, FlaxUNet2DConditionModel
7
  from diffusers.schedulers import FlaxPNDMScheduler
8
  from datasets import load_dataset
9
  from tqdm.auto import tqdm
@@ -53,22 +53,8 @@ pipeline, params = get_model(model_id, "flax")
53
  custom_scheduler = CustomFlaxPNDMScheduler.from_config(pipeline.scheduler.config)
54
  pipeline.scheduler = custom_scheduler
55
 
56
- # Modify UNet configuration
57
- unet_config = pipeline.unet.config
58
- unet_config.in_channels = 4 # Set to match the latent space dimensions
59
-
60
- # Create a new UNet with the modified configuration
61
- unet = FlaxUNet2DConditionModel(unet_config)
62
-
63
- # Initialize the new UNet with random weights
64
- rng = jax.random.PRNGKey(0)
65
- sample_input = jnp.ones((1, 64, 64, 4))
66
- sample_t = jnp.ones((1,))
67
- sample_encoder_hidden_states = jnp.ones((1, 77, 768))
68
- new_unet_params = unet.init(rng, sample_input, sample_t, sample_encoder_hidden_states)["params"]
69
-
70
- # Replace the UNet params in the pipeline
71
- params["unet"] = new_unet_params
72
 
73
  # Load and preprocess your dataset
74
  def preprocess_images(examples):
 
3
  from flax.jax_utils import replicate
4
  from flax.training import train_state
5
  import optax
6
+ from diffusers import FlaxStableDiffusionPipeline
7
  from diffusers.schedulers import FlaxPNDMScheduler
8
  from datasets import load_dataset
9
  from tqdm.auto import tqdm
 
53
  custom_scheduler = CustomFlaxPNDMScheduler.from_config(pipeline.scheduler.config)
54
  pipeline.scheduler = custom_scheduler
55
 
56
+ # Extract UNet from pipeline
57
+ unet = pipeline.unet
 
 
 
 
 
 
 
 
 
 
 
 
 
 
58
 
59
  # Load and preprocess your dataset
60
  def preprocess_images(examples):