Maverick98 commited on
Commit
d478ff8
·
verified ·
1 Parent(s): 0bca9cd
Files changed (1) hide show
  1. app.py +9 -3
app.py CHANGED
@@ -56,12 +56,18 @@ checkpoint_url = f"https://huggingface.co/Maverick98/EcommerceClassifier/resolve
56
  checkpoint = torch.hub.load_state_dict_from_url(checkpoint_url, map_location=torch.device('cpu'))
57
 
58
  # Strip the "module." prefix from the keys in the state_dict if they exist
 
 
59
  new_state_dict = {}
60
- for k, v in checkpoint.items():
61
  if k.startswith("module."):
62
- new_state_dict[k[7:]] = v # Remove "module." prefix
63
  else:
64
- new_state_dict[k] = v
 
 
 
 
65
 
66
  model.load_state_dict(new_state_dict)
67
 
 
56
  checkpoint = torch.hub.load_state_dict_from_url(checkpoint_url, map_location=torch.device('cpu'))
57
 
58
  # Strip the "module." prefix from the keys in the state_dict if they exist
59
+ # Clean up the state dictionary
60
+ state_dict = checkpoint.get('model_state_dict', checkpoint)
61
  new_state_dict = {}
62
+ for k, v in state_dict.items():
63
  if k.startswith("module."):
64
+ new_key = k[7:] # Remove "module." prefix
65
  else:
66
+ new_key = k
67
+
68
+ # Check if the new_key exists in the model's state_dict, only add if it does
69
+ if new_key in model.state_dict():
70
+ new_state_dict[new_key] = v
71
 
72
  model.load_state_dict(new_state_dict)
73