Tej3 commited on
Commit
92f6568
·
1 Parent(s): 8b94e3d

Modifying App.py

Browse files
Files changed (1) hide show
  1. app.py +40 -33
app.py CHANGED
@@ -19,6 +19,26 @@ def cropping(img):
19
 
20
  return img
21
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
22
  DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
23
  print(DEVICE)
24
  CWD = "."
@@ -36,48 +56,35 @@ CKPT_FILE_NAMES = {
36
  }
37
  MODEL_CLASSES = {
38
  'Indoor': {
39
- 'Resnet_enc':enc_dec_model,
40
- 'Unet':ResNet18UNet,
41
- 'Densenet_enc':Densenet
42
  },
43
 
44
  'Outdoor': {
45
- 'Resnet_enc':enc_dec_model,
46
- 'Unet':UNetWithResnet50Encoder,
47
- 'Densenet_enc':Densenet
48
  },
49
-
50
  }
 
 
 
 
 
 
51
 
52
- def load_model(ckpt, model, optimizer=None):
53
- ckpt_dict = torch.load(ckpt, map_location='cpu')
54
- # keep backward compatibility
55
- if 'model' not in ckpt_dict and 'optimizer' not in ckpt_dict:
56
- state_dict = ckpt_dict
57
- else:
58
- state_dict = ckpt_dict['model']
59
- weights = {}
60
- for key, value in state_dict.items():
61
- if key.startswith('module.'):
62
- weights[key[len('module.'):]] = value
63
- else:
64
- weights[key] = value
65
-
66
- model.load_state_dict(weights)
67
-
68
- if optimizer is not None:
69
- optimizer_state = ckpt_dict['optimizer']
70
- optimizer.load_state_dict(optimizer_state)
71
 
72
 
73
  def predict(location, model_name, img):
74
- ckpt_dir = f"{CWD}/ckpt/{CKPT_FILE_NAMES[location][model_name]}"
75
- if location == 'nyu':
76
- max_depth = 10
77
- else:
78
- max_depth = 80
79
- model = MODEL_CLASSES[location][model_name](max_depth).to(DEVICE)
80
- load_model(ckpt_dir,model)
 
81
  # print(img.shape)
82
  # assert False
83
  if img.shape == (375,1242,3):
 
19
 
20
  return img
21
 
22
+ def load_model(ckpt, model, optimizer=None):
23
+ ckpt_dict = torch.load(ckpt, map_location='cpu')
24
+ # keep backward compatibility
25
+ if 'model' not in ckpt_dict and 'optimizer' not in ckpt_dict:
26
+ state_dict = ckpt_dict
27
+ else:
28
+ state_dict = ckpt_dict['model']
29
+ weights = {}
30
+ for key, value in state_dict.items():
31
+ if key.startswith('module.'):
32
+ weights[key[len('module.'):]] = value
33
+ else:
34
+ weights[key] = value
35
+
36
+ model.load_state_dict(weights)
37
+
38
+ if optimizer is not None:
39
+ optimizer_state = ckpt_dict['optimizer']
40
+ optimizer.load_state_dict(optimizer_state)
41
+
42
  DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
43
  print(DEVICE)
44
  CWD = "."
 
56
  }
57
  MODEL_CLASSES = {
58
  'Indoor': {
59
+ 'Resnet_enc':enc_dec_model(max_depth = 10),
60
+ 'Unet':ResNet18UNet(max_depth = 10),
61
+ 'Densenet_enc':Densenet(max_depth = 10)
62
  },
63
 
64
  'Outdoor': {
65
+ 'Resnet_enc':enc_dec_model(max_depth = 80),
66
+ 'Unet':UNetWithResnet50Encoder(max_depth = 80),
67
+ 'Densenet_enc':Densenet(max_depth = 80)
68
  },
 
69
  }
70
+ location_types = ['Indoor', 'Outdoor']
71
+ Models = ['Resnet_enc','Unet','Densenet_enc']
72
+ for location in location_types:
73
+ for model in Models:
74
+ ckpt_dir = f"{CWD}/ckpt/{CKPT_FILE_NAMES[location][model]}"
75
+ load_model(CKPT_FILE_NAMES[location][model], MODEL_CLASSES[location][model])
76
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
77
 
78
 
79
  def predict(location, model_name, img):
80
+ # ckpt_dir = f"{CWD}/ckpt/{CKPT_FILE_NAMES[location][model_name]}"
81
+ # if location == 'nyu':
82
+ # max_depth = 10
83
+ # else:
84
+ # max_depth = 80
85
+ # model = MODEL_CLASSES[location][model_name](max_depth).to(DEVICE)
86
+ model = MODEL_CLASSES[location][model_name].to(DEVICE)
87
+ # load_model(ckpt_dir,model)
88
  # print(img.shape)
89
  # assert False
90
  if img.shape == (375,1242,3):