uruguayai commited on
Commit
f17ed04
·
verified ·
1 Parent(s): 41e0af4

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +18 -3
app.py CHANGED
@@ -74,13 +74,28 @@ print(unet.config)
74
 
75
  # Adjust the input layer of the UNet
76
  def adjust_unet_input_layer(params):
77
- conv_in_weight = params['conv_in']['kernel']
 
 
 
 
 
 
 
 
 
78
  print(f"Original conv_in weight shape: {conv_in_weight.shape}")
79
  if conv_in_weight.shape[2] != 4:
80
  new_conv_in_weight = jnp.zeros((3, 3, 4, 320), dtype=jnp.float32)
81
  new_conv_in_weight = new_conv_in_weight.at[:, :, :3, :].set(conv_in_weight[:, :, :3, :])
82
- params['conv_in']['kernel'] = new_conv_in_weight
83
- print(f"New conv_in weight shape: {params['conv_in']['kernel'].shape}")
 
 
 
 
 
 
84
  return params
85
 
86
  params = adjust_unet_input_layer(params)
 
74
 
75
  # Adjust the input layer of the UNet
76
  def adjust_unet_input_layer(params):
77
+ if 'unet' in params:
78
+ unet_params = params['unet']
79
+ else:
80
+ unet_params = params
81
+
82
+ if 'conv_in' not in unet_params:
83
+ print("Warning: 'conv_in' not found in UNet params. Skipping input layer adjustment.")
84
+ return params
85
+
86
+ conv_in_weight = unet_params['conv_in']['kernel']
87
  print(f"Original conv_in weight shape: {conv_in_weight.shape}")
88
  if conv_in_weight.shape[2] != 4:
89
  new_conv_in_weight = jnp.zeros((3, 3, 4, 320), dtype=jnp.float32)
90
  new_conv_in_weight = new_conv_in_weight.at[:, :, :3, :].set(conv_in_weight[:, :, :3, :])
91
+ unet_params['conv_in']['kernel'] = new_conv_in_weight
92
+ print(f"New conv_in weight shape: {unet_params['conv_in']['kernel'].shape}")
93
+
94
+ if 'unet' in params:
95
+ params['unet'] = unet_params
96
+ else:
97
+ params = unet_params
98
+
99
  return params
100
 
101
  params = adjust_unet_input_layer(params)