DemoCropMapping / app.py
ignaziogallo
resolved numpy bug
6cb40a3
raw
history blame
7.44 kB
import os
import streamlit as st
import zipfile
import torch
from utils import *
import matplotlib.pyplot as plt
from matplotlib import colors
# The list of result files is kept as an attribute on the `st` module itself
# (presumably so it outlives a single execution of this script). Keep any value
# already stored there; otherwise start with None, meaning "nothing to show".
st.paths = getattr(st, 'paths', None)
# Build both networks and restore their checkpoints, but only when they are not
# already attached to the `st` module; the guard skips all of this whenever
# `st.daily_model` is already present.
if not hasattr(st, 'daily_model'):
    best_model_daily_file_name = "best_model_daily.pth"
    best_model_annual_file_name = "best_model_annual.pth"

    # All-zero tensor handed to the FPN constructor together with the options.
    first_input_batch = torch.zeros(71, 9, 5, 48, 48)
    st.daily_model = FPN(opt, first_input_batch, opt.win_size)
    st.annual_model = SimpleNN(opt)

    use_cuda = torch.cuda.is_available()

    def _parallelize(model):
        """Wrap *model* in DataParallel and move it to the GPU if available, else CPU."""
        parallel = torch.nn.DataParallel(model)
        return parallel.cuda() if use_cuda else parallel.cpu()

    # NOTE(review): each model is intentionally wrapped TWICE, exactly as before.
    # Nested DataParallel wrappers add one "module." prefix per level to the
    # state-dict keys, so the saved checkpoints presumably expect this layout --
    # confirm against the checkpoint files before simplifying to a single wrap.
    st.daily_model = _parallelize(_parallelize(st.daily_model))
    st.annual_model = _parallelize(_parallelize(st.annual_model))

    print('trying to resume previous saved models...')
    daily_checkpoint = os.path.join(opt.resume_path, best_model_daily_file_name)
    state = resume(daily_checkpoint, model=st.daily_model, optimizer=None)
    annual_checkpoint = os.path.join(opt.resume_path, best_model_annual_file_name)
    state = resume(annual_checkpoint, model=st.annual_model, optimizer=None)

    # Switch both networks to evaluation mode; the app only runs inference.
    st.daily_model = st.daily_model.eval()
    st.annual_model = st.annual_model.eval()
# Page header and input selection: the user either uploads a zip archive or
# presses the button to use the bundled demo sample.
st.title('Lombardia Sentinel 2 daily Crop Mapping')
# NOTE(review): "og" in the text below looks like a typo for "of".
st.markdown('Using a daily FPN and giving a zip that contains 30 tiff with 7 channels, correctly named you can reach prediction of crop mapping og the area.')
file_uploaded = st.file_uploader("Upload", type=["zip"], accept_multiple_files=False)

# Stays None until one of the two input sources below provides a folder.
sample_path = None
if file_uploaded is not None:
    with zipfile.ZipFile(file_uploaded, "r") as archive:
        archive.extractall("uploaded_samples")
    # The sample folder is taken to be the upload's name minus its last four
    # characters (the ".zip" suffix) -- this assumes the archive contains a
    # top-level folder with that same name.
    sample_path = f"uploaded_samples/{file_uploaded.name[:-4]}"

st.markdown('or use a demo sample')
# Checked after the upload, so pressing the button overrides an uploaded file.
if st.button('sample 1'):
    sample_path = 'demo_data/lombardia'
# Run inference on the selected sample folder and write the predictions as
# single-band uint8 rasters under <opt.result_path>/seg_maps:
#   * one annual map per patch ("patch-pred-nn.tif"), and
#   * one daily map per patch and time step ("<date>-ch<k>-b<j>-daily-pred.tif").
if sample_path is not None:
    st.markdown(f'elaborating {sample_path}...')
    validationdataset = SentinelDailyAnnualDatasetNoLabel(
        sample_path,
        opt.years,
        opt.classes_path,
        opt.sample_duration,
        opt.win_size,
        tileids=None)
    validationdataloader = torch.utils.data.DataLoader(
        validationdataset,
        batch_size=opt.batch_size,
        shuffle=False,
        num_workers=opt.workers)

    st.markdown('predict in progress...')
    out_dir = os.path.join(opt.result_path, "seg_maps")
    os.makedirs(out_dir, exist_ok=True)

    for x_dailies, dates, dirs_path in validationdataloader:
        with torch.no_grad():
            # Merge the first two dimensions so the daily model receives every
            # time step as a separate item.
            x_dailies = x_dailies.view(-1, *x_dailies.shape[2:])
            if torch.cuda.is_available():
                x_dailies = x_dailies.cuda()
            feat_daily, outs_daily = st.daily_model.forward(x_dailies)

            # Split the merged dimension back into (batch, time step, ...).
            # The batch size is inferred (-1) instead of hard-coding
            # opt.batch_size, because the loader keeps the last batch even when
            # it is smaller, which made the fixed-size view raise.
            outs_daily = outs_daily.view(
                -1, opt.sample_duration, *outs_daily.shape[1:])
            feat_daily = feat_daily.view(
                -1, opt.sample_duration, *feat_daily.shape[1:])

            _, out_annual = st.annual_model.forward(feat_daily)
            pred_annual = torch.argmax(out_annual, dim=1).squeeze(1)
            pred_annual = pred_annual.cpu().numpy()

            # Remap the predicted class indices to label values.
            pred_annual_nn = ids_to_labels(
                validationdataloader, pred_annual).astype(numpy.uint8)

            for batch in range(feat_daily.shape[0]):
                # The raster profile is read from the last-dated tif of the
                # dataset's FIRST sample (index 0) for every patch.
                # TODO (carried over from the original): get the last image of
                # the patch actually being written.
                _, tmp_path = get_patch_id(validationdataset.samples, 0)
                # Kept under its own name: this list used to overwrite the
                # loader-provided `dates`, so the daily file names below were
                # built from a single character of a date string.
                patch_dates = get_all_dates(
                    tmp_path, validationdataset.max_seq_length)
                last_tif_path = os.path.join(tmp_path, patch_dates[-1] + ".tif")
                _, profile = read(last_tif_path)
                profile["name"] = dirs_path[batch]
                # Same single-band uint8 profile for the annual and daily maps.
                profile.update({
                    'nodata': None,
                    'dtype': 'uint8',
                    'count': 1})

                # Output folder is derived from the last three components of
                # the sample's directory path.
                pth = dirs_path[batch].split(os.path.sep)[-3:]
                full_pth_patch = os.path.join(
                    out_dir, pth[1] + '-' + pth[0], pth[2])
                os.makedirs(full_pth_patch, exist_ok=True)

                full_pth_pred = os.path.join(
                    full_pth_patch, 'patch-pred-nn.tif')
                with rasterio.open(full_pth_pred, 'w', **profile) as dst:
                    dst.write_band(1, pred_annual_nn[batch])

                # NOTE(review): assumes the loader collates `dates` as one
                # entry per time step, each holding one date string per sample
                # (what the `[ch][batch]` indexing implies) -- confirm against
                # the dataset's __getitem__.
                for ch, step_dates in enumerate(dates):
                    soft_seg = outs_daily[batch, ch, :, :, :]
                    # Turn the per-class scores into a hard segmentation.
                    pred_daily = torch.argmax(soft_seg, dim=0).cpu()
                    daily_pred = ids_to_labels(
                        validationdataloader, pred_daily).astype(numpy.uint8)
                    full_pth_date = os.path.join(
                        full_pth_patch,
                        step_dates[batch] + f'-ch{ch}-b{batch}-daily-pred.tif')
                    with rasterio.open(full_pth_date, 'w', **profile) as dst:
                        dst.write_band(1, daily_pred)

    st.markdown('End prediction')
    # NOTE(review): the viewer lists this fixed results folder, not `out_dir`.
    folder = "demo_data/results/seg_maps/example-lombardia/2"
    st.paths = os.listdir(folder)
# Result viewer: let the user pick one of the listed prediction rasters and
# show it, colour-coded, next to a legend of the classes.
if st.paths is not None:
    folder = "demo_data/results/seg_maps/example-lombardia/2"

    # Pre-select the annual prediction when it is listed; fall back to the
    # first entry instead of raising ValueError when it is missing.
    annual_name = 'patch-pred-nn.tif'
    default_index = st.paths.index(annual_name) if annual_name in st.paths else 0
    file_picker = st.selectbox(
        "Select day predict (annual is patch-pred-nn.tif)",
        st.paths, index=default_index)
    file_path = os.path.join(folder, file_picker)

    target, profile = read(file_path)
    target = np.squeeze(target)
    # Look up the display colour for each row of class values.
    target = [classes_color_map[p] for p in target]

    fig, ax = plt.subplots()
    ax.imshow(target)

    # One coloured <div> per class, in the order of color_labels / labels_map.
    markdown_legend = ''.join(
        f'<div style="color:gray;background-color: {colors.to_hex(color)};">{label}</div><br>'
        for color, label in zip(color_labels, labels_map))

    col1, col2 = st.columns(2)
    with col1:
        st.pyplot(fig)
    with col2:
        st.markdown(markdown_legend, unsafe_allow_html=True)

    # pyplot keeps every figure it creates alive until it is closed; release
    # this one after rendering so figures do not pile up each time the script
    # runs again.
    plt.close(fig)