Spaces:
Running
on
Zero
Running
on
Zero
Nan Xue
commited on
Commit
·
4621597
1
Parent(s):
8305e16
update
Browse files
scalelsd/ssl/misc/train_utils.py
CHANGED
@@ -42,7 +42,13 @@ def fix_seeds(random_seed):
|
|
42 |
|
43 |
def load_scalelsd_model(ckpt_path, device='cuda'):
|
44 |
"""load model"""
|
45 |
-
use_layer_scale =
|
|
|
|
|
|
|
|
|
|
|
|
|
46 |
|
47 |
model = ScaleLSD(gray_scale=True, use_layer_scale=use_layer_scale)
|
48 |
model = model.eval().to(device)
|
|
|
42 |
|
43 |
def load_scalelsd_model(ckpt_path, device='cuda'):
|
44 |
"""load model"""
|
45 |
+
use_layer_scale = True
|
46 |
+
if os.path.basename(ckpt_path) == 'scalelsd-vitbase-v1-train-sa1b.pt':
|
47 |
+
use_layer_scale = False
|
48 |
+
elif os.path.basename(ckpt_path) == 'scalelsd-vitbase-v2-train-sa1b.pt':
|
49 |
+
use_layer_scale = True
|
50 |
+
else:
|
51 |
+
raise ValueError(f'Unknown model: {os.path.basename(ckpt_path)}')
|
52 |
|
53 |
model = ScaleLSD(gray_scale=True, use_layer_scale=use_layer_scale)
|
54 |
model = model.eval().to(device)
|