artelabsuper commited on
Commit
99ee6d2
·
1 Parent(s): 7d56262

cached models improve speed

Browse files
Files changed (1) hide show
  1. app.py +35 -32
app.py CHANGED
@@ -8,38 +8,41 @@ from matplotlib import colors
8
 
9
  if not hasattr(st, 'paths'):
10
  st.paths = None
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
11
 
12
  # Load Model
13
  # @title Load pretrained weights
14
 
15
- best_model_daily_file_name = "best_model_daily.pth"
16
- best_model_annual_file_name = "best_model_annual.pth"
17
-
18
- first_input_batch = torch.zeros(71, 9, 5, 48, 48)
19
- # first_input_batch = first_input_batch.view(-1, *first_input_batch.shape[2:])
20
- daily_model = FPN(opt, first_input_batch, opt.win_size)
21
- annual_model = SimpleNN(opt)
22
-
23
- if torch.cuda.is_available():
24
- daily_model = torch.nn.DataParallel(daily_model).cuda()
25
- annual_model = torch.nn.DataParallel(annual_model).cuda()
26
- daily_model = torch.nn.DataParallel(daily_model).cuda()
27
- annual_model = torch.nn.DataParallel(annual_model).cuda()
28
- else:
29
- daily_model = torch.nn.DataParallel(daily_model).cpu()
30
- annual_model = torch.nn.DataParallel(annual_model).cpu()
31
- daily_model = torch.nn.DataParallel(daily_model).cpu()
32
- annual_model = torch.nn.DataParallel(annual_model).cpu()
33
-
34
- print('trying to resume previous saved models...')
35
- state = resume(
36
- os.path.join(opt.resume_path, best_model_daily_file_name),
37
- model=daily_model, optimizer=None)
38
- state = resume(
39
- os.path.join(opt.resume_path, best_model_annual_file_name),
40
- model=annual_model, optimizer=None)
41
- daily_model = daily_model.eval()
42
- annual_model = annual_model.eval()
43
 
44
  st.title('Lombardia Sentinel 2 daily Crop Mapping')
45
  st.markdown('Using a daily FPN and giving a zip that contains 30 tiff with 7 channels, correctly named you can reach prediction of crop mapping og the area.')
@@ -85,14 +88,14 @@ if sample_path is not None:
85
  if torch.cuda.is_available():
86
  x_dailies = x_dailies.cuda()
87
 
88
- feat_daily, outs_daily = daily_model.forward(x_dailies)
89
  # return to original size of batch and year
90
  outs_daily = outs_daily.view(
91
  opt.batch_size, opt.sample_duration, *outs_daily.shape[1:])
92
  feat_daily = feat_daily.view(
93
  opt.batch_size, opt.sample_duration, *feat_daily.shape[1:])
94
 
95
- _, out_annual = annual_model.forward(feat_daily)
96
  pred_annual = torch.argmax(out_annual, dim=1).squeeze(1)
97
  pred_annual = pred_annual.cpu().numpy()
98
  # Remapping the labels
@@ -158,7 +161,7 @@ if st.paths is not None:
158
  st.paths, index=st.paths.index('patch-pred-nn.tif'))
159
 
160
  file_path = os.path.join(folder, file_picker)
161
- print(file_path)
162
  target, profile = read(file_path)
163
  target = np.squeeze(target)
164
  target = [classes_color_map[p] for p in target]
@@ -169,7 +172,7 @@ if st.paths is not None:
169
 
170
  markdown_legend = ''
171
  for c, l in zip(classes_color_map, labels_map):
172
- print(colors.to_hex(c))
173
  markdown_legend += f'<div style="color:gray;background-color: {colors.to_hex(c)};">{l}</div><br>'
174
 
175
  col1, col2 = st.columns(2)
 
8
 
9
  if not hasattr(st, 'paths'):
10
  st.paths = None
11
+ if not hasattr(st, 'daily_model'):
12
+ best_model_daily_file_name = "best_model_daily.pth"
13
+ best_model_annual_file_name = "best_model_annual.pth"
14
+
15
+ first_input_batch = torch.zeros(71, 9, 5, 48, 48)
16
+ # first_input_batch = first_input_batch.view(-1, *first_input_batch.shape[2:])
17
+ st.daily_model = FPN(opt, first_input_batch, opt.win_size)
18
+ st.annual_model = SimpleNN(opt)
19
+
20
+ if torch.cuda.is_available():
21
+ st.daily_model = torch.nn.DataParallel(st.daily_model).cuda()
22
+ st.annual_model = torch.nn.DataParallel(st.annual_model).cuda()
23
+ st.daily_model = torch.nn.DataParallel(st.daily_model).cuda()
24
+ st.annual_model = torch.nn.DataParallel(st.annual_model).cuda()
25
+ else:
26
+ st.daily_model = torch.nn.DataParallel(st.daily_model).cpu()
27
+ st.annual_model = torch.nn.DataParallel(st.annual_model).cpu()
28
+ st.daily_model = torch.nn.DataParallel(st.daily_model).cpu()
29
+ st.annual_model = torch.nn.DataParallel(st.annual_model).cpu()
30
+
31
+ print('trying to resume previous saved models...')
32
+ state = resume(
33
+ os.path.join(opt.resume_path, best_model_daily_file_name),
34
+ model=st.daily_model, optimizer=None)
35
+ state = resume(
36
+ os.path.join(opt.resume_path, best_model_annual_file_name),
37
+ model=st.annual_model, optimizer=None)
38
+ st.daily_model = st.daily_model.eval()
39
+ st.annual_model = st.annual_model.eval()
40
+
41
 
42
  # Load Model
43
  # @title Load pretrained weights
44
 
45
+
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
46
 
47
  st.title('Lombardia Sentinel 2 daily Crop Mapping')
48
  st.markdown('Using a daily FPN and giving a zip that contains 30 tiff with 7 channels, correctly named you can reach prediction of crop mapping og the area.')
 
88
  if torch.cuda.is_available():
89
  x_dailies = x_dailies.cuda()
90
 
91
+ feat_daily, outs_daily = st.daily_model.forward(x_dailies)
92
  # return to original size of batch and year
93
  outs_daily = outs_daily.view(
94
  opt.batch_size, opt.sample_duration, *outs_daily.shape[1:])
95
  feat_daily = feat_daily.view(
96
  opt.batch_size, opt.sample_duration, *feat_daily.shape[1:])
97
 
98
+ _, out_annual = st.annual_model.forward(feat_daily)
99
  pred_annual = torch.argmax(out_annual, dim=1).squeeze(1)
100
  pred_annual = pred_annual.cpu().numpy()
101
  # Remapping the labels
 
161
  st.paths, index=st.paths.index('patch-pred-nn.tif'))
162
 
163
  file_path = os.path.join(folder, file_picker)
164
+ # print(file_path)
165
  target, profile = read(file_path)
166
  target = np.squeeze(target)
167
  target = [classes_color_map[p] for p in target]
 
172
 
173
  markdown_legend = ''
174
  for c, l in zip(classes_color_map, labels_map):
175
+ # print(colors.to_hex(c))
176
  markdown_legend += f'<div style="color:gray;background-color: {colors.to_hex(c)};">{l}</div><br>'
177
 
178
  col1, col2 = st.columns(2)