Commit
·
124f0e8
1
Parent(s):
66676eb
torchvision nms bug fix
Browse files- utils/torch_utils.py +1 -1
utils/torch_utils.py
CHANGED
@@ -8,6 +8,7 @@ import torch
|
|
8 |
import torch.backends.cudnn as cudnn
|
9 |
import torch.nn as nn
|
10 |
import torch.nn.functional as F
|
|
|
11 |
|
12 |
logger = logging.getLogger(__name__)
|
13 |
|
@@ -151,7 +152,6 @@ def model_info(model, verbose=False):
|
|
151 |
|
152 |
def load_classifier(name='resnet101', n=2):
|
153 |
# Loads a pretrained model reshaped to n-class output
|
154 |
-
import torchvision
|
155 |
model = torchvision.models.__dict__[name](pretrained=True)
|
156 |
|
157 |
# ResNet model properties
|
|
|
8 |
import torch.backends.cudnn as cudnn
|
9 |
import torch.nn as nn
|
10 |
import torch.nn.functional as F
|
11 |
+
import torchvision
|
12 |
|
13 |
logger = logging.getLogger(__name__)
|
14 |
|
|
|
152 |
|
153 |
def load_classifier(name='resnet101', n=2):
|
154 |
# Loads a pretrained model reshaped to n-class output
|
|
|
155 |
model = torchvision.models.__dict__[name](pretrained=True)
|
156 |
|
157 |
# ResNet model properties
|