edits
app.py
CHANGED
@@ -29,10 +29,18 @@ state = torch.load('fire.pth', map_location='cpu')
 state['net_params']['pretrained'] = None # no need for imagenet pretrained model
 net_sfm = fire_network.init_network(**state['net_params']).to(device)
 net_sfm.load_state_dict(state['state_dict'])
+dim_red_params_dict = {}
+for name, param in net_sfm.named_parameters():
+    if 'dim_reduction' in name:
+        dim_red_params_dict[name] = param
+
 
 state2 = torch.load('fire_imagenet.pth', map_location='cpu')
+state2['net_params'] = state['net_params']
+state2['state_dict'] += dim_red_params_dict
+# state2['net_params'] =
 net_imagenet = fire_network.init_network(**state['net_params']).to(device)
-net_imagenet.load_state_dict(state2['state_dict'])
+net_imagenet.load_state_dict(state2['state_dict']) #, strict=False)
 
 # ---------------------------------------
 transform = transforms.Compose([
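
For reference, a minimal sketch of what the added merge step appears to be doing: copy the dim_reduction weights from the SfM checkpoint into the ImageNet state dict before loading it into the second network. This assumes the fire_network module and the two checkpoint files shown in the diff; the use of dict.update() and strict=False is an assumption on my part, since the "+=" in the added line would raise a TypeError on a plain dict and the commented-out strict=False hints at remaining key mismatches.

# Sketch only, not the Space's final code. Assumes 'fire.pth', 'fire_imagenet.pth'
# and fire_network.init_network() as they appear in the diff.
import torch
import fire_network

device = 'cpu'  # stands in for whatever device app.py selects earlier

state = torch.load('fire.pth', map_location='cpu')
state['net_params']['pretrained'] = None  # no need for the ImageNet-pretrained backbone here
net_sfm = fire_network.init_network(**state['net_params']).to(device)
net_sfm.load_state_dict(state['state_dict'])

# Collect the dimensionality-reduction weights that exist only in the SfM checkpoint.
dim_red_params_dict = {
    name: tensor
    for name, tensor in net_sfm.state_dict().items()
    if 'dim_reduction' in name
}

state2 = torch.load('fire_imagenet.pth', map_location='cpu')
state2['net_params'] = state['net_params']
# Merge the missing dim_reduction weights into the ImageNet state dict
# (dict.update() rather than '+=', which dicts do not support).
state2['state_dict'].update(dim_red_params_dict)

net_imagenet = fire_network.init_network(**state2['net_params']).to(device)
# strict=False tolerates any keys that still differ between the two checkpoints.
net_imagenet.load_state_dict(state2['state_dict'], strict=False)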