Update app.py
Browse files
app.py
CHANGED
@@ -56,12 +56,18 @@ pipeline.scheduler = custom_scheduler
|
|
56 |
# Extract UNet from pipeline
|
57 |
unet = pipeline.unet
|
58 |
|
|
|
|
|
|
|
|
|
59 |
# Adjust the input layer of the UNet
|
60 |
def adjust_unet_input_layer(params, target_in_channels=4, copied_in_channels=3):
    """Rebuild the UNet ``conv_in`` kernel with ``target_in_channels`` inputs.

    A zero-initialised float32 kernel is created and the leading
    ``copied_in_channels`` input-channel slices of the existing kernel are
    copied into it; all remaining input channels stay zero.

    The spatial size and the number of output channels are taken from the
    existing kernel instead of being hard-coded to ``(3, 3, ..., 320)``, so
    the function also works for checkpoints with a different ``conv_in``
    geometry. With the default arguments and a ``(3, 3, C>=3, 320)`` kernel
    the result is identical to the previous hard-coded implementation.

    Args:
        params: Nested mapping holding the kernel at
            ``params['unet']['conv_in']['kernel']``. It is updated in place.
            The kernel is assumed to be 4-D with input channels on axis 2
            (looks like the Flax HWIO layout - TODO confirm).
        target_in_channels: Number of input channels of the new kernel.
        copied_in_channels: Upper bound on how many leading input channels
            are copied from the existing kernel.

    Returns:
        The same ``params`` object, with the kernel replaced.

    Raises:
        ValueError: If the existing kernel is not 4-dimensional.
    """
    conv_in_weight = params['unet']['conv_in']['kernel']
    if len(conv_in_weight.shape) != 4:
        raise ValueError(
            "conv_in kernel must be 4-D (height, width, in, out), "
            f"got shape {tuple(conv_in_weight.shape)}"
        )
    kernel_h, kernel_w, source_in_channels, out_channels = conv_in_weight.shape

    # Never copy more channels than either kernel actually has; slicing past
    # the source would otherwise make `.set` broadcast a smaller slice.
    n_copy = min(copied_in_channels, source_in_channels, target_in_channels)

    # NOTE(review): when the source kernel already has 4 input channels, the
    # weights of channel index 3 are dropped (left at zero) with the default
    # arguments - confirm that this is intended.
    new_conv_in_weight = jnp.zeros(
        (kernel_h, kernel_w, target_in_channels, out_channels),
        dtype=jnp.float32,
    )
    new_conv_in_weight = new_conv_in_weight.at[:, :, :n_copy, :].set(
        conv_in_weight[:, :, :n_copy, :]
    )
    params['unet']['conv_in']['kernel'] = new_conv_in_weight
    return params
|
66 |
|
67 |
params = adjust_unet_input_layer(params)
|
@@ -103,6 +109,10 @@ else:
|
|
103 |
|
104 |
print(f"Processed dataset size: {len(processed_dataset)}")
|
105 |
|
|
|
|
|
|
|
|
|
106 |
# Training function
|
107 |
def train_step(state, batch, rng):
|
108 |
def compute_loss(params, pixel_values, rng):
|
@@ -134,6 +144,10 @@ def train_step(state, batch, rng):
|
|
134 |
dtype=jnp.float32
|
135 |
)
|
136 |
|
|
|
|
|
|
|
|
|
137 |
# Use the correct method to call the UNet
|
138 |
model_output = unet.apply(
|
139 |
{'params': params["unet"]},
|
|
|
56 |
# Extract UNet from pipeline
|
57 |
unet = pipeline.unet
|
58 |
|
59 |
+
# Print UNet configuration
|
60 |
+
print("UNet configuration:")
|
61 |
+
print(unet.config)
|
62 |
+
|
63 |
# Adjust the input layer of the UNet
|
64 |
def adjust_unet_input_layer(params, target_in_channels=4, copied_in_channels=3):
    """Rebuild the UNet ``conv_in`` kernel with ``target_in_channels`` inputs.

    A zero-initialised float32 kernel is created and the leading
    ``copied_in_channels`` input-channel slices of the existing kernel are
    copied into it; all remaining input channels stay zero. The kernel shape
    is printed before and after the change for debugging.

    The spatial size and the number of output channels are taken from the
    existing kernel instead of being hard-coded to ``(3, 3, ..., 320)``, so
    the function also works for checkpoints with a different ``conv_in``
    geometry. With the default arguments and a ``(3, 3, C>=3, 320)`` kernel
    the result is identical to the previous hard-coded implementation.

    Args:
        params: Nested mapping holding the kernel at
            ``params['unet']['conv_in']['kernel']``. It is updated in place.
            The kernel is assumed to be 4-D with input channels on axis 2
            (looks like the Flax HWIO layout - TODO confirm).
        target_in_channels: Number of input channels of the new kernel.
        copied_in_channels: Upper bound on how many leading input channels
            are copied from the existing kernel.

    Returns:
        The same ``params`` object, with the kernel replaced.

    Raises:
        ValueError: If the existing kernel is not 4-dimensional.
    """
    conv_in_weight = params['unet']['conv_in']['kernel']
    print(f"Original conv_in weight shape: {conv_in_weight.shape}")
    if len(conv_in_weight.shape) != 4:
        raise ValueError(
            "conv_in kernel must be 4-D (height, width, in, out), "
            f"got shape {tuple(conv_in_weight.shape)}"
        )
    kernel_h, kernel_w, source_in_channels, out_channels = conv_in_weight.shape

    # Never copy more channels than either kernel actually has; slicing past
    # the source would otherwise make `.set` broadcast a smaller slice.
    n_copy = min(copied_in_channels, source_in_channels, target_in_channels)

    # NOTE(review): when the source kernel already has 4 input channels, the
    # weights of channel index 3 are dropped (left at zero) with the default
    # arguments - confirm that this is intended.
    new_conv_in_weight = jnp.zeros(
        (kernel_h, kernel_w, target_in_channels, out_channels),
        dtype=jnp.float32,
    )
    new_conv_in_weight = new_conv_in_weight.at[:, :, :n_copy, :].set(
        conv_in_weight[:, :, :n_copy, :]
    )
    params['unet']['conv_in']['kernel'] = new_conv_in_weight
    print(f"New conv_in weight shape: {params['unet']['conv_in']['kernel'].shape}")
    return params
|
72 |
|
73 |
params = adjust_unet_input_layer(params)
|
|
|
109 |
|
110 |
print(f"Processed dataset size: {len(processed_dataset)}")
|
111 |
|
112 |
+
# Print sample input shape
|
113 |
+
sample_batch = next(iter(processed_dataset.batch(1)))
|
114 |
+
print(f"Sample input shape: {sample_batch['pixel_values'].shape}")
|
115 |
+
|
116 |
# Training function
|
117 |
def train_step(state, batch, rng):
|
118 |
def compute_loss(params, pixel_values, rng):
|
|
|
144 |
dtype=jnp.float32
|
145 |
)
|
146 |
|
147 |
+
print(f"noisy_latents shape: {noisy_latents.shape}")
|
148 |
+
print(f"timesteps shape: {timesteps.shape}")
|
149 |
+
print(f"encoder_hidden_states shape: {encoder_hidden_states.shape}")
|
150 |
+
|
151 |
# Use the correct method to call the UNet
|
152 |
model_output = unet.apply(
|
153 |
{'params': params["unet"]},
|