Nan Xue commited on
Commit
4621597
·
1 Parent(s): 8305e16
Files changed (1) hide show
  1. scalelsd/ssl/misc/train_utils.py +7 -1
scalelsd/ssl/misc/train_utils.py CHANGED
@@ -42,7 +42,13 @@ def fix_seeds(random_seed):
42
 
43
  def load_scalelsd_model(ckpt_path, device='cuda'):
44
  """load model"""
45
- use_layer_scale = False if 'v1' in ckpt_path else True
 
 
 
 
 
 
46
 
47
  model = ScaleLSD(gray_scale=True, use_layer_scale=use_layer_scale)
48
  model = model.eval().to(device)
 
42
 
43
  def load_scalelsd_model(ckpt_path, device='cuda'):
44
  """load model"""
45
+ use_layer_scale = True
46
+ if os.path.basename(ckpt_path) == 'scalelsd-vitbase-v1-train-sa1b.pt':
47
+ use_layer_scale = False
48
+ elif os.path.basename(ckpt_path) == 'scalelsd-vitbase-v2-train-sa1b.pt':
49
+ use_layer_scale = True
50
+ else:
51
+ raise ValueError(f'Unknown model: {os.path.basename(ckpt_path)}')
52
 
53
  model = ScaleLSD(gray_scale=True, use_layer_scale=use_layer_scale)
54
  model = model.eval().to(device)