Spaces:
Runtime error
Runtime error
echen01
commited on
Commit
•
f7bf9fb
1
Parent(s):
5b7158a
fix deeplab device
Browse files- criteria/deeplab.py +2 -2
- criteria/mask.py +1 -0
criteria/deeplab.py
CHANGED
@@ -309,7 +309,7 @@ def resnet50(pretrained=False, **kwargs):
|
|
309 |
return model
|
310 |
|
311 |
|
312 |
-
def resnet101(path=None, pretrained=False, num_groups=None, weight_std=False, **kwargs):
|
313 |
"""Constructs a ResNet-101 model.
|
314 |
|
315 |
Args:
|
@@ -326,7 +326,7 @@ def resnet101(path=None, pretrained=False, num_groups=None, weight_std=False, **
|
|
326 |
model_dict = model.state_dict()
|
327 |
if num_groups and weight_std:
|
328 |
path = os.path.join(os.path.dirname(path), "R-101-GN-WS.pth.tar")
|
329 |
-
pretrained_dict = torch.load(path)
|
330 |
overlap_dict = {
|
331 |
k[7:]: v for k, v in pretrained_dict.items() if k[7:] in model_dict
|
332 |
}
|
|
|
309 |
return model
|
310 |
|
311 |
|
312 |
+
def resnet101(path=None, pretrained=False, num_groups=None, weight_std=False, device="cpu", **kwargs):
|
313 |
"""Constructs a ResNet-101 model.
|
314 |
|
315 |
Args:
|
|
|
326 |
model_dict = model.state_dict()
|
327 |
if num_groups and weight_std:
|
328 |
path = os.path.join(os.path.dirname(path), "R-101-GN-WS.pth.tar")
|
329 |
+
pretrained_dict = torch.load(path, map_location=device)
|
330 |
overlap_dict = {
|
331 |
k[7:]: v for k, v in pretrained_dict.items() if k[7:] in model_dict
|
332 |
}
|
criteria/mask.py
CHANGED
@@ -36,6 +36,7 @@ class Mask(nn.Module):
|
|
36 |
num_groups=32,
|
37 |
weight_std=True,
|
38 |
beta=False,
|
|
|
39 |
)
|
40 |
.eval()
|
41 |
.requires_grad_(False)
|
|
|
36 |
num_groups=32,
|
37 |
weight_std=True,
|
38 |
beta=False,
|
39 |
+
device=device,
|
40 |
)
|
41 |
.eval()
|
42 |
.requires_grad_(False)
|