Spaces:
Running
Running
update 3
Browse files
app.py
CHANGED
@@ -56,12 +56,18 @@ checkpoint_url = f"https://huggingface.co/Maverick98/EcommerceClassifier/resolve
|
|
56 |
checkpoint = torch.hub.load_state_dict_from_url(checkpoint_url, map_location=torch.device('cpu'))
|
57 |
|
58 |
# Strip the "module." prefix from the keys in the state_dict if they exist
|
|
|
|
|
59 |
new_state_dict = {}
|
60 |
-
for k, v in
|
61 |
if k.startswith("module."):
|
62 |
-
|
63 |
else:
|
64 |
-
|
|
|
|
|
|
|
|
|
65 |
|
66 |
model.load_state_dict(new_state_dict)
|
67 |
|
|
|
56 |
checkpoint = torch.hub.load_state_dict_from_url(checkpoint_url, map_location=torch.device('cpu'))
|
57 |
|
58 |
# Strip the "module." prefix from the keys in the state_dict if they exist
|
59 |
+
# Clean up the state dictionary
|
60 |
+
state_dict = checkpoint.get('model_state_dict', checkpoint)
|
61 |
new_state_dict = {}
|
62 |
+
for k, v in state_dict.items():
|
63 |
if k.startswith("module."):
|
64 |
+
new_key = k[7:] # Remove "module." prefix
|
65 |
else:
|
66 |
+
new_key = k
|
67 |
+
|
68 |
+
# Check if the new_key exists in the model's state_dict, only add if it does
|
69 |
+
if new_key in model.state_dict():
|
70 |
+
new_state_dict[new_key] = v
|
71 |
|
72 |
model.load_state_dict(new_state_dict)
|
73 |
|