Update app.py
Browse files
app.py
CHANGED
@@ -53,40 +53,18 @@ pipeline, params = get_model(model_id, "flax")
|
|
53 |
custom_scheduler = CustomFlaxPNDMScheduler.from_config(pipeline.scheduler.config)
|
54 |
pipeline.scheduler = custom_scheduler
|
55 |
|
56 |
-
#
|
57 |
-
|
58 |
-
|
59 |
-
|
60 |
-
|
61 |
-
|
62 |
-
|
63 |
-
|
64 |
-
|
65 |
-
|
66 |
-
"DownBlock2D"
|
67 |
-
]
|
68 |
-
config.up_block_types = [
|
69 |
-
"UpBlock2D",
|
70 |
-
"CrossAttnUpBlock2D",
|
71 |
-
"CrossAttnUpBlock2D",
|
72 |
-
"CrossAttnUpBlock2D"
|
73 |
-
]
|
74 |
-
return config
|
75 |
-
|
76 |
-
modified_unet_config = modify_unet_config(unet_config)
|
77 |
-
|
78 |
-
# Create a new UNet with the modified configuration
|
79 |
-
unet = FlaxUNet2DConditionModel(modified_unet_config)
|
80 |
-
|
81 |
-
# Initialize the new UNet with random weights
|
82 |
-
rng = jax.random.PRNGKey(0)
|
83 |
-
sample_input = jnp.ones((1, 64, 64, 4))
|
84 |
-
sample_t = jnp.ones((1,))
|
85 |
-
sample_encoder_hidden_states = jnp.ones((1, 77, 768))
|
86 |
-
new_unet_params = unet.init(rng, sample_input, sample_t, sample_encoder_hidden_states)["params"]
|
87 |
|
88 |
-
|
89 |
-
params["unet"] = new_unet_params
|
90 |
|
91 |
# Load and preprocess your dataset
|
92 |
def preprocess_images(examples):
|
@@ -124,10 +102,6 @@ print(f"Processed dataset size: {len(processed_dataset)}")
|
|
124 |
# Training function
|
125 |
def train_step(state, batch, rng):
|
126 |
def compute_loss(params, pixel_values, rng):
|
127 |
-
print("pixel_values dtype:", pixel_values.dtype)
|
128 |
-
print("params dtypes:", jax.tree_map(lambda x: x.dtype, params))
|
129 |
-
print("rng dtype:", rng.dtype)
|
130 |
-
|
131 |
pixel_values = jnp.array(pixel_values, dtype=jnp.float32)
|
132 |
|
133 |
latents = pipeline.vae.apply(
|
@@ -143,11 +117,6 @@ def train_step(state, batch, rng):
|
|
143 |
rng, (latents.shape[0],), 0, pipeline.scheduler.config.num_train_timesteps
|
144 |
)
|
145 |
|
146 |
-
print("timesteps dtype:", timesteps.dtype)
|
147 |
-
print("latents dtype:", latents.dtype)
|
148 |
-
print("noise dtype:", noise.dtype)
|
149 |
-
print("latents shape:", latents.shape)
|
150 |
-
|
151 |
noisy_latents = pipeline.scheduler.add_noise(
|
152 |
pipeline.scheduler.create_state(),
|
153 |
original_samples=latents,
|
|
|
53 |
custom_scheduler = CustomFlaxPNDMScheduler.from_config(pipeline.scheduler.config)
|
54 |
pipeline.scheduler = custom_scheduler
|
55 |
|
56 |
+
# Extract UNet from pipeline
|
57 |
+
unet = pipeline.unet
|
58 |
+
|
59 |
+
# Adjust the input layer of the UNet
|
60 |
+
def adjust_unet_input_layer(params):
|
61 |
+
conv_in_weight = params['unet']['conv_in']['kernel']
|
62 |
+
new_conv_in_weight = jnp.zeros((3, 3, 4, 320), dtype=jnp.float32)
|
63 |
+
new_conv_in_weight = new_conv_in_weight.at[:, :, :4, :].set(conv_in_weight[:, :, :4, :])
|
64 |
+
params['unet']['conv_in']['kernel'] = new_conv_in_weight
|
65 |
+
return params
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
66 |
|
67 |
+
params = adjust_unet_input_layer(params)
|
|
|
68 |
|
69 |
# Load and preprocess your dataset
|
70 |
def preprocess_images(examples):
|
|
|
102 |
# Training function
|
103 |
def train_step(state, batch, rng):
|
104 |
def compute_loss(params, pixel_values, rng):
|
|
|
|
|
|
|
|
|
105 |
pixel_values = jnp.array(pixel_values, dtype=jnp.float32)
|
106 |
|
107 |
latents = pipeline.vae.apply(
|
|
|
117 |
rng, (latents.shape[0],), 0, pipeline.scheduler.config.num_train_timesteps
|
118 |
)
|
119 |
|
|
|
|
|
|
|
|
|
|
|
120 |
noisy_latents = pipeline.scheduler.add_noise(
|
121 |
pipeline.scheduler.create_state(),
|
122 |
original_samples=latents,
|