uruguayai commited on
Commit
66bb520
·
verified ·
1 Parent(s): 7166f76

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +14 -0
app.py CHANGED
@@ -56,12 +56,18 @@ pipeline.scheduler = custom_scheduler
56
  # Extract UNet from pipeline
57
  unet = pipeline.unet
58
 
 
 
 
 
59
  # Adjust the input layer of the UNet
60
  def adjust_unet_input_layer(params):
61
  conv_in_weight = params['unet']['conv_in']['kernel']
 
62
  new_conv_in_weight = jnp.zeros((3, 3, 4, 320), dtype=jnp.float32)
63
  new_conv_in_weight = new_conv_in_weight.at[:, :, :3, :].set(conv_in_weight[:, :, :3, :])
64
  params['unet']['conv_in']['kernel'] = new_conv_in_weight
 
65
  return params
66
 
67
  params = adjust_unet_input_layer(params)
@@ -103,6 +109,10 @@ else:
103
 
104
  print(f"Processed dataset size: {len(processed_dataset)}")
105
 
 
 
 
 
106
  # Training function
107
  def train_step(state, batch, rng):
108
  def compute_loss(params, pixel_values, rng):
@@ -134,6 +144,10 @@ def train_step(state, batch, rng):
134
  dtype=jnp.float32
135
  )
136
 
 
 
 
 
137
  # Use the correct method to call the UNet
138
  model_output = unet.apply(
139
  {'params': params["unet"]},
 
56
  # Extract UNet from pipeline
57
  unet = pipeline.unet
58
 
59
+ # Print UNet configuration
60
+ print("UNet configuration:")
61
+ print(unet.config)
62
+
63
  # Adjust the input layer of the UNet
64
  def adjust_unet_input_layer(params):
65
  conv_in_weight = params['unet']['conv_in']['kernel']
66
+ print(f"Original conv_in weight shape: {conv_in_weight.shape}")
67
  new_conv_in_weight = jnp.zeros((3, 3, 4, 320), dtype=jnp.float32)
68
  new_conv_in_weight = new_conv_in_weight.at[:, :, :3, :].set(conv_in_weight[:, :, :3, :])
69
  params['unet']['conv_in']['kernel'] = new_conv_in_weight
70
+ print(f"New conv_in weight shape: {params['unet']['conv_in']['kernel'].shape}")
71
  return params
72
 
73
  params = adjust_unet_input_layer(params)
 
109
 
110
  print(f"Processed dataset size: {len(processed_dataset)}")
111
 
112
+ # Print sample input shape
113
+ sample_batch = next(iter(processed_dataset.batch(1)))
114
+ print(f"Sample input shape: {sample_batch['pixel_values'].shape}")
115
+
116
  # Training function
117
  def train_step(state, batch, rng):
118
  def compute_loss(params, pixel_values, rng):
 
144
  dtype=jnp.float32
145
  )
146
 
147
+ print(f"noisy_latents shape: {noisy_latents.shape}")
148
+ print(f"timesteps shape: {timesteps.shape}")
149
+ print(f"encoder_hidden_states shape: {encoder_hidden_states.shape}")
150
+
151
  # Use the correct method to call the UNet
152
  model_output = unet.apply(
153
  {'params': params["unet"]},