uruguayai commited on
Commit
ed67914
·
verified ·
1 Parent(s): 629ceb5

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +5 -2
app.py CHANGED
@@ -64,8 +64,11 @@ print(unet.config)
64
  def adjust_unet_input_layer(params):
65
  conv_in_weight = params['unet']['conv_in']['kernel']
66
  print(f"Original conv_in weight shape: {conv_in_weight.shape}")
67
- new_conv_in_weight = jnp.zeros((3, 3, 4, 320), dtype=jnp.float32)
68
- new_conv_in_weight = new_conv_in_weight.at[:, :, :3, :].set(conv_in_weight[:, :, :3, :])
 
 
 
69
  params['unet']['conv_in']['kernel'] = new_conv_in_weight
70
  print(f"New conv_in weight shape: {params['unet']['conv_in']['kernel'].shape}")
71
  return params
 
64
  def adjust_unet_input_layer(params):
65
  conv_in_weight = params['unet']['conv_in']['kernel']
66
  print(f"Original conv_in weight shape: {conv_in_weight.shape}")
67
+ if conv_in_weight.shape[2] == 64:
68
+ new_conv_in_weight = conv_in_weight[:, :, :4, :]
69
+ else:
70
+ new_conv_in_weight = jnp.zeros((3, 3, 4, 320), dtype=jnp.float32)
71
+ new_conv_in_weight = new_conv_in_weight.at[:, :, :3, :].set(conv_in_weight[:, :, :3, :])
72
  params['unet']['conv_in']['kernel'] = new_conv_in_weight
73
  print(f"New conv_in weight shape: {params['unet']['conv_in']['kernel'].shape}")
74
  return params