trying to figure out state dict
Browse files
app.py
CHANGED
@@ -296,9 +296,9 @@ def predict_and_analyze(model_name, num_channels, dim, input_channel, image):
|
|
296 |
# print(model_url)
|
297 |
|
298 |
loaded = torch.load(model_url, map_location='cpu', )
|
299 |
-
|
300 |
|
301 |
-
model.load_state_dict(loaded)
|
302 |
# print(model)
|
303 |
|
304 |
# model = EfficientNetPreTrained(config)
|
|
|
296 |
# print(model_url)
|
297 |
|
298 |
loaded = torch.load(model_url, map_location='cpu', )
|
299 |
+
print(loaded.keys())
|
300 |
|
301 |
+
model.load_state_dict(loaded['state_dict'])
|
302 |
# print(model)
|
303 |
|
304 |
# model = EfficientNetPreTrained(config)
|