Spaces:
Sleeping
Sleeping
artelabsuper
commited on
Commit
·
99ee6d2
1
Parent(s):
7d56262
cached models improve speed
Browse files
app.py
CHANGED
@@ -8,38 +8,41 @@ from matplotlib import colors
|
|
8 |
|
9 |
if not hasattr(st, 'paths'):
|
10 |
st.paths = None
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
11 |
|
12 |
# Load Model
|
13 |
# @title Load pretrained weights
|
14 |
|
15 |
-
|
16 |
-
best_model_annual_file_name = "best_model_annual.pth"
|
17 |
-
|
18 |
-
first_input_batch = torch.zeros(71, 9, 5, 48, 48)
|
19 |
-
# first_input_batch = first_input_batch.view(-1, *first_input_batch.shape[2:])
|
20 |
-
daily_model = FPN(opt, first_input_batch, opt.win_size)
|
21 |
-
annual_model = SimpleNN(opt)
|
22 |
-
|
23 |
-
if torch.cuda.is_available():
|
24 |
-
daily_model = torch.nn.DataParallel(daily_model).cuda()
|
25 |
-
annual_model = torch.nn.DataParallel(annual_model).cuda()
|
26 |
-
daily_model = torch.nn.DataParallel(daily_model).cuda()
|
27 |
-
annual_model = torch.nn.DataParallel(annual_model).cuda()
|
28 |
-
else:
|
29 |
-
daily_model = torch.nn.DataParallel(daily_model).cpu()
|
30 |
-
annual_model = torch.nn.DataParallel(annual_model).cpu()
|
31 |
-
daily_model = torch.nn.DataParallel(daily_model).cpu()
|
32 |
-
annual_model = torch.nn.DataParallel(annual_model).cpu()
|
33 |
-
|
34 |
-
print('trying to resume previous saved models...')
|
35 |
-
state = resume(
|
36 |
-
os.path.join(opt.resume_path, best_model_daily_file_name),
|
37 |
-
model=daily_model, optimizer=None)
|
38 |
-
state = resume(
|
39 |
-
os.path.join(opt.resume_path, best_model_annual_file_name),
|
40 |
-
model=annual_model, optimizer=None)
|
41 |
-
daily_model = daily_model.eval()
|
42 |
-
annual_model = annual_model.eval()
|
43 |
|
44 |
st.title('Lombardia Sentinel 2 daily Crop Mapping')
|
45 |
st.markdown('Using a daily FPN and giving a zip that contains 30 tiff with 7 channels, correctly named you can reach prediction of crop mapping og the area.')
|
@@ -85,14 +88,14 @@ if sample_path is not None:
|
|
85 |
if torch.cuda.is_available():
|
86 |
x_dailies = x_dailies.cuda()
|
87 |
|
88 |
-
feat_daily, outs_daily = daily_model.forward(x_dailies)
|
89 |
# return to original size of batch and year
|
90 |
outs_daily = outs_daily.view(
|
91 |
opt.batch_size, opt.sample_duration, *outs_daily.shape[1:])
|
92 |
feat_daily = feat_daily.view(
|
93 |
opt.batch_size, opt.sample_duration, *feat_daily.shape[1:])
|
94 |
|
95 |
-
_, out_annual = annual_model.forward(feat_daily)
|
96 |
pred_annual = torch.argmax(out_annual, dim=1).squeeze(1)
|
97 |
pred_annual = pred_annual.cpu().numpy()
|
98 |
# Remapping the labels
|
@@ -158,7 +161,7 @@ if st.paths is not None:
|
|
158 |
st.paths, index=st.paths.index('patch-pred-nn.tif'))
|
159 |
|
160 |
file_path = os.path.join(folder, file_picker)
|
161 |
-
print(file_path)
|
162 |
target, profile = read(file_path)
|
163 |
target = np.squeeze(target)
|
164 |
target = [classes_color_map[p] for p in target]
|
@@ -169,7 +172,7 @@ if st.paths is not None:
|
|
169 |
|
170 |
markdown_legend = ''
|
171 |
for c, l in zip(classes_color_map, labels_map):
|
172 |
-
print(colors.to_hex(c))
|
173 |
markdown_legend += f'<div style="color:gray;background-color: {colors.to_hex(c)};">{l}</div><br>'
|
174 |
|
175 |
col1, col2 = st.columns(2)
|
|
|
8 |
|
9 |
if not hasattr(st, 'paths'):
|
10 |
st.paths = None
|
11 |
+
if not hasattr(st, 'daily_model'):
|
12 |
+
best_model_daily_file_name = "best_model_daily.pth"
|
13 |
+
best_model_annual_file_name = "best_model_annual.pth"
|
14 |
+
|
15 |
+
first_input_batch = torch.zeros(71, 9, 5, 48, 48)
|
16 |
+
# first_input_batch = first_input_batch.view(-1, *first_input_batch.shape[2:])
|
17 |
+
st.daily_model = FPN(opt, first_input_batch, opt.win_size)
|
18 |
+
st.annual_model = SimpleNN(opt)
|
19 |
+
|
20 |
+
if torch.cuda.is_available():
|
21 |
+
st.daily_model = torch.nn.DataParallel(st.daily_model).cuda()
|
22 |
+
st.annual_model = torch.nn.DataParallel(st.annual_model).cuda()
|
23 |
+
st.daily_model = torch.nn.DataParallel(st.daily_model).cuda()
|
24 |
+
st.annual_model = torch.nn.DataParallel(st.annual_model).cuda()
|
25 |
+
else:
|
26 |
+
st.daily_model = torch.nn.DataParallel(st.daily_model).cpu()
|
27 |
+
st.annual_model = torch.nn.DataParallel(st.annual_model).cpu()
|
28 |
+
st.daily_model = torch.nn.DataParallel(st.daily_model).cpu()
|
29 |
+
st.annual_model = torch.nn.DataParallel(st.annual_model).cpu()
|
30 |
+
|
31 |
+
print('trying to resume previous saved models...')
|
32 |
+
state = resume(
|
33 |
+
os.path.join(opt.resume_path, best_model_daily_file_name),
|
34 |
+
model=st.daily_model, optimizer=None)
|
35 |
+
state = resume(
|
36 |
+
os.path.join(opt.resume_path, best_model_annual_file_name),
|
37 |
+
model=st.annual_model, optimizer=None)
|
38 |
+
st.daily_model = st.daily_model.eval()
|
39 |
+
st.annual_model = st.annual_model.eval()
|
40 |
+
|
41 |
|
42 |
# Load Model
|
43 |
# @title Load pretrained weights
|
44 |
|
45 |
+
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
46 |
|
47 |
st.title('Lombardia Sentinel 2 daily Crop Mapping')
|
48 |
st.markdown('Using a daily FPN and giving a zip that contains 30 tiff with 7 channels, correctly named you can reach prediction of crop mapping og the area.')
|
|
|
88 |
if torch.cuda.is_available():
|
89 |
x_dailies = x_dailies.cuda()
|
90 |
|
91 |
+
feat_daily, outs_daily = st.daily_model.forward(x_dailies)
|
92 |
# return to original size of batch and year
|
93 |
outs_daily = outs_daily.view(
|
94 |
opt.batch_size, opt.sample_duration, *outs_daily.shape[1:])
|
95 |
feat_daily = feat_daily.view(
|
96 |
opt.batch_size, opt.sample_duration, *feat_daily.shape[1:])
|
97 |
|
98 |
+
_, out_annual = st.annual_model.forward(feat_daily)
|
99 |
pred_annual = torch.argmax(out_annual, dim=1).squeeze(1)
|
100 |
pred_annual = pred_annual.cpu().numpy()
|
101 |
# Remapping the labels
|
|
|
161 |
st.paths, index=st.paths.index('patch-pred-nn.tif'))
|
162 |
|
163 |
file_path = os.path.join(folder, file_picker)
|
164 |
+
# print(file_path)
|
165 |
target, profile = read(file_path)
|
166 |
target = np.squeeze(target)
|
167 |
target = [classes_color_map[p] for p in target]
|
|
|
172 |
|
173 |
markdown_legend = ''
|
174 |
for c, l in zip(classes_color_map, labels_map):
|
175 |
+
# print(colors.to_hex(c))
|
176 |
markdown_legend += f'<div style="color:gray;background-color: {colors.to_hex(c)};">{l}</div><br>'
|
177 |
|
178 |
col1, col2 = st.columns(2)
|