chirmy commited on
Commit
3a5d86c
·
verified ·
1 Parent(s): b1f9ac8

Update script.py

Browse files

Annotate code:
# num_features = model.get_classifier().in_features
# model.classifier = nn.Linear(num_features, number_of_categories)

Files changed (1) hide show
  1. script.py +2 -2
script.py CHANGED
@@ -36,8 +36,8 @@ class PytorchWorker:
36
  model.load_state_dict(model_ckpt, strict=False)
37
  msg = model.load_state_dict(model_ckpt, strict=False)
38
  print("load_state_dict: ", msg)
39
- num_features = model.get_classifier().in_features
40
- model.classifier = nn.Linear(num_features, number_of_categories)
41
 
42
  return model.to(self.device).eval()
43
 
 
36
  model.load_state_dict(model_ckpt, strict=False)
37
  msg = model.load_state_dict(model_ckpt, strict=False)
38
  print("load_state_dict: ", msg)
39
+ # num_features = model.get_classifier().in_features
40
+ # model.classifier = nn.Linear(num_features, number_of_categories)
41
 
42
  return model.to(self.device).eval()
43