Merge pull request #79 from Lornatang/delete-redundant-thrid-party-library
Browse files- utils/torch_utils.py +12 -7
utils/torch_utils.py
CHANGED
@@ -7,6 +7,7 @@ import torch
|
|
7 |
import torch.backends.cudnn as cudnn
|
8 |
import torch.nn as nn
|
9 |
import torch.nn.functional as F
|
|
|
10 |
|
11 |
|
12 |
def init_seeds(seed=0):
|
@@ -120,18 +121,22 @@ def model_info(model, verbose=False):
|
|
120 |
|
121 |
def load_classifier(name='resnet101', n=2):
|
122 |
# Loads a pretrained model reshaped to n-class output
|
123 |
-
|
124 |
-
model = pretrainedmodels.__dict__[name](num_classes=1000, pretrained='imagenet')
|
125 |
|
126 |
# Display model properties
|
127 |
-
|
|
|
|
|
|
|
|
|
|
|
128 |
print(x + ' =', eval(x))
|
129 |
|
130 |
# Reshape output to n classes
|
131 |
-
filters = model.
|
132 |
-
model.
|
133 |
-
model.
|
134 |
-
model.
|
135 |
return model
|
136 |
|
137 |
|
|
|
7 |
import torch.backends.cudnn as cudnn
|
8 |
import torch.nn as nn
|
9 |
import torch.nn.functional as F
|
10 |
+
import torchvision.models as models
|
11 |
|
12 |
|
13 |
def init_seeds(seed=0):
|
|
|
121 |
|
122 |
def load_classifier(name='resnet101', n=2):
|
123 |
# Loads a pretrained model reshaped to n-class output
|
124 |
+
model = models.__dict__[name](pretrained=True)
|
|
|
125 |
|
126 |
# Display model properties
|
127 |
+
input_size = [3, 224, 224]
|
128 |
+
input_space = 'RGB'
|
129 |
+
input_range = [0, 1]
|
130 |
+
mean = [0.485, 0.456, 0.406]
|
131 |
+
std = [0.229, 0.224, 0.225]
|
132 |
+
for x in [input_size, input_space, input_range, mean, std]:
|
133 |
print(x + ' =', eval(x))
|
134 |
|
135 |
# Reshape output to n classes
|
136 |
+
filters = model.fc.weight.shape[1]
|
137 |
+
model.fc.bias = torch.nn.Parameter(torch.zeros(n), requires_grad=True)
|
138 |
+
model.fc.weight = torch.nn.Parameter(torch.zeros(n, filters), requires_grad=True)
|
139 |
+
model.fc.out_features = n
|
140 |
return model
|
141 |
|
142 |
|