Update script.py
Browse filesAnnotate code:
# num_features = model.get_classifier().in_features
# model.classifier = nn.Linear(num_features, number_of_categories)
script.py
CHANGED
@@ -36,8 +36,8 @@ class PytorchWorker:
|
|
36 |
model.load_state_dict(model_ckpt, strict=False)
|
37 |
msg = model.load_state_dict(model_ckpt, strict=False)
|
38 |
print("load_state_dict: ", msg)
|
39 |
-
num_features = model.get_classifier().in_features
|
40 |
-
model.classifier = nn.Linear(num_features, number_of_categories)
|
41 |
|
42 |
return model.to(self.device).eval()
|
43 |
|
|
|
36 |
model.load_state_dict(model_ckpt, strict=False)
|
37 |
msg = model.load_state_dict(model_ckpt, strict=False)
|
38 |
print("load_state_dict: ", msg)
|
39 |
+
# num_features = model.get_classifier().in_features
|
40 |
+
# model.classifier = nn.Linear(num_features, number_of_categories)
|
41 |
|
42 |
return model.to(self.device).eval()
|
43 |
|