uruguayai commited on
Commit
16dd569
·
verified ·
1 Parent(s): ae48a93

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +16 -0
app.py CHANGED
@@ -56,6 +56,22 @@ pipeline.scheduler = custom_scheduler
56
  # Extract UNet from pipeline
57
  unet = pipeline.unet
58
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
59
  # Load and preprocess your dataset
60
  def preprocess_images(examples):
61
  def process_image(image):
 
56
  # Extract UNet from pipeline
57
  unet = pipeline.unet
58
 
59
+ # Modify the UNet's input layer
60
+ def modify_unet_input_layer(params):
61
+ conv_in_weight = params['unet']['conv_in']['kernel']
62
+ conv_in_bias = params['unet']['conv_in']['bias']
63
+
64
+ # Adjust the weight tensor
65
+ new_weight = jnp.zeros((3, 3, 4, 320), dtype=jnp.float32)
66
+ new_weight = new_weight.at[:, :, :4, :].set(conv_in_weight[:, :, :4, :])
67
+
68
+ # Update the parameters
69
+ params['unet']['conv_in']['kernel'] = new_weight
70
+
71
+ return params
72
+
73
+ params = modify_unet_input_layer(params)
74
+
75
  # Load and preprocess your dataset
76
  def preprocess_images(examples):
77
  def process_image(image):