DemoCropMapping / app.py
ignaziogallo
added new samples
a563c94
raw
history blame
11.6 kB
import streamlit as st
import zipfile
from utils import *
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.animation as animation
import streamlit.components.v1 as components
from matplotlib import colors
st.set_page_config(layout="wide")
def create_animation(images, pred_dates, *, fps=2):
    """Build a matplotlib animation that shows ``images`` one frame at a time.

    Each frame is labelled with ``"date: " + pred_dates[idx]``.

    Args:
        images: sequence of images accepted by ``imshow``; ``images[0]``
            initialises the plot, so it must not be empty.
        pred_dates: date strings, one per entry of ``images``.
        fps: playback speed in frames per second (keyword-only, default 2).

    Returns:
        The ``matplotlib.animation.FuncAnimation`` object (not yet rendered;
        callers render it, e.g. with ``to_jshtml()``).
    """
    print('Creating composition of images...')
    fig_an, ax_an = plt.subplots()
    # Use the axes we just created rather than pyplot's implicit "current" axes.
    ax_an.set_title("")
    im = ax_an.imshow(images[0], interpolation='none', aspect='auto', vmin=0, vmax=1)
    # Semi-transparent date label drawn inside the axes (axes-relative coordinates).
    title = ax_an.text(0.5, 0.85, "", bbox={'facecolor': 'w', 'alpha': 0.5, 'pad': 5},
                       transform=ax_an.transAxes, ha="center")

    def animate_func(idx):
        title.set_text("date: " + pred_dates[idx])
        im.set_array(images[idx])
        # With blit=True only the returned artists are redrawn, so the date
        # label must be returned as well or it would keep showing a stale date.
        return [im, title]

    anima = animation.FuncAnimation(fig_an, animate_func, frames=len(images),
                                    interval=1000 / fps, blit=True, repeat=False)
    print('Done!')
    return anima
def load_daily_preds_as_animations(pred_full_paths, pred_dates):
    """Colourise every daily prediction raster and animate them in order.

    Args:
        pred_full_paths: paths of the daily prediction rasters, read with ``read``.
        pred_dates: one date label per path, shown on the matching frame.

    Returns:
        The animation built by ``create_animation``.
    """
    def _colorize(raster_path):
        # ``read`` returns (array, profile); only the label array is needed.
        labels, _profile = read(raster_path)
        # Row-wise lookup in ``classes_color_map`` (presumably a NumPy array
        # mapping label value -> RGB colour; confirm in utils).
        return [classes_color_map[row] for row in np.squeeze(labels)]

    frames = [_colorize(raster_path) for raster_path in pred_full_paths]
    return create_animation(frames, pred_dates)
def load_src_images_as_animations(img_paths, pred_dates):
    """Animate true-colour composites of the input Sentinel-2 images.

    Args:
        img_paths: paths of the multi-band input rasters, read with ``read``
            (band-first layout, as the ``img[[2, 1, 0], :, :]`` indexing implies).
        pred_dates: one date label per image, shown on the matching frame.

    Returns:
        The animation built by ``create_animation``.
    """
    # Composite recipes: https://custom-scripts.sentinel-hub.com/custom-scripts/sentinel-2/composites/
    # Assumed band order (per the original notes): B2, B3, B4, B5, B6, B7, B8, B11, B12,
    # so indices [2, 1, 0] = B4, B3, B2 = red, green, blue (true colour). TODO confirm.
    imgs = []
    for path in img_paths:
        img, _ = read(path)
        rgb = np.moveaxis(img[[2, 1, 0], :, :], 0, -1)
        peak = np.amax(rgb)
        if peak > 0:
            # Scale by the global maximum so float RGB fits imshow's [0, 1] range.
            imgs.append(rgb / peak)
        else:
            # All-zero (e.g. nodata) frame: 0/0 would give NaNs, show it black instead.
            imgs.append(np.zeros(rgb.shape, dtype=float))
    return create_animation(imgs, pred_dates)
# State is stored as attributes of the `streamlit` module object, which is only
# (re)initialised when the attribute is missing; presumably this is used as a
# cache that survives script reruns.
# NOTE(review): module attributes are shared by every session served by the same
# process, not per user -- consider st.session_state / st.cache_resource.
if not hasattr(st, 'paths'):
    # File names found in the last prediction output folder; None until a prediction ran.
    st.paths = None
if not hasattr(st, 'daily_model'):
    # Build the two networks and load their weights only once.
    best_model_daily_file_name = "best_model_daily.pth"
    best_model_annual_file_name = "best_model_annual.pth"
    # Dummy tensor passed to the FPN constructor. NOTE(review): presumably used to
    # size the network; the meaning of the (71, 9, 5, 48, 48) dims is not visible here.
    first_input_batch = torch.zeros(71, 9, 5, 48, 48)
    # first_input_batch = first_input_batch.view(-1, *first_input_batch.shape[2:])
    st.daily_model = FPN(opt, first_input_batch, opt.win_size)
    st.annual_model = SimpleNN(opt)
    # NOTE(review): in both branches each model is wrapped in DataParallel TWICE
    # (parameters end up under `module.module.`). This may be deliberate to match
    # the key prefixes stored in the checkpoints loaded by `resume` below --
    # confirm before removing the second wrap.
    if torch.cuda.is_available():
        st.daily_model = torch.nn.DataParallel(st.daily_model).cuda()
        st.annual_model = torch.nn.DataParallel(st.annual_model).cuda()
        st.daily_model = torch.nn.DataParallel(st.daily_model).cuda()
        st.annual_model = torch.nn.DataParallel(st.annual_model).cuda()
    else:
        st.daily_model = torch.nn.DataParallel(st.daily_model).cpu()
        st.annual_model = torch.nn.DataParallel(st.annual_model).cpu()
        st.daily_model = torch.nn.DataParallel(st.daily_model).cpu()
        st.annual_model = torch.nn.DataParallel(st.annual_model).cpu()
    print('trying to resume previous saved models...')
    # Load the pretrained weights from opt.resume_path (no optimizer state needed).
    # The returned `state` is not used afterwards.
    state = resume(
        os.path.join(opt.resume_path, best_model_daily_file_name),
        model=st.daily_model, optimizer=None)
    state = resume(
        os.path.join(opt.resume_path, best_model_annual_file_name),
        model=st.annual_model, optimizer=None)
    # Inference only: put both models in evaluation mode.
    st.daily_model = st.daily_model.eval()
    st.annual_model = st.annual_model.eval()
# Page header: paper title, citation (markdown) and the zip uploader for user samples.
_PAGE_TITLE = 'In-season and dynamic crop mapping using 3D convolution neural networks and sentinel-2 time series'
_CITATION_MD = """ Demo App for the model presented in the [paper](https://www.sciencedirect.com/science/article/pii/S0924271622003203):
```
@article{gallo2022in_season,
title = {In-season and dynamic crop mapping using 3D convolution neural networks and sentinel-2 time series},
journal = {ISPRS Journal of Photogrammetry and Remote Sensing},
volume = {195},
pages = {335-352},
year = {2023},
issn = {0924-2716},
doi = {https://doi.org/10.1016/j.isprsjprs.2022.12.005},
url = {https://www.sciencedirect.com/science/article/pii/S0924271622003203},
author = {Ignazio Gallo and Luigi Ranghetti and Nicola Landro and Riccardo {La Grassa} and Mirco Boschetti},
}
```
**NOTE: The demo doesn't work properly, we are working to fix the bugs!**
"""
st.title(_PAGE_TITLE)
st.markdown(_CITATION_MD)
# Optional user input: a single .zip archive containing a sample to process.
file_uploaded = st.file_uploader("Upload", type=["zip"], accept_multiple_files=False)
# Decide which sample to process in this run: an uploaded zip or a demo tile.
sample_path = None
tileids = None
if file_uploaded is not None:
    with zipfile.ZipFile(file_uploaded, "r") as archive:
        archive.extractall("uploaded_samples")
    # Strip the ".zip" suffix; presumably the archive holds a folder of the same name.
    sample_path = "uploaded_samples/" + file_uploaded.name[:-4]
st.markdown('or use a demo sample')
# Every button is rendered on each run, in this order; a pressed one selects its tile
# (a later pressed button would override an earlier one).
for _button_label, _tile_id in (('sample 1', '24'), ('sample 2', '712'),
                                ('sample 3', '814'), ('sample 4', '1509')):
    if st.button(_button_label):
        sample_path = 'demo_data/lombardia'
        tileids = [_tile_id]
paths = None
if sample_path is not None:
    # Loader over the selected sample (uploaded zip or demo tile).
    validationdataset = SentinelDailyAnnualDatasetNoLabel(
        sample_path,
        opt.years,
        opt.classes_path,
        opt.sample_duration,
        opt.win_size,
        tileids=tileids)
    validationdataloader = torch.utils.data.DataLoader(
        validationdataset, batch_size=opt.batch_size, shuffle=False, num_workers=opt.workers)
    st.markdown('Model prediction in progress ...')
    out_dir = os.path.join(opt.result_path, "seg_maps")
    os.makedirs(out_dir, exist_ok=True)
    # Output directory of the last patch written below; it is the one displayed afterwards.
    last_patch_dir = None
    for x_dailies, _batch_dates, dirs_path in validationdataloader:
        with torch.no_grad():
            # Keep the real (batch, time) sizes: the last batch may hold fewer than
            # opt.batch_size samples, so the view-back below must not use opt.batch_size.
            n_batch, n_steps = x_dailies.shape[:2]
            # Merge the first two dimensions so the daily model sees one item per time step.
            x_dailies = x_dailies.view(-1, *x_dailies.shape[2:])
            if torch.cuda.is_available():
                x_dailies = x_dailies.cuda()
            feat_daily, outs_daily = st.daily_model.forward(x_dailies)
            # Restore the (batch, time, ...) layout.
            outs_daily = outs_daily.view(n_batch, n_steps, *outs_daily.shape[1:])
            feat_daily = feat_daily.view(n_batch, n_steps, *feat_daily.shape[1:])
            _, out_annual = st.annual_model.forward(feat_daily)
            pred_annual = torch.argmax(out_annual, dim=1).squeeze(1)
            pred_annual = pred_annual.cpu().numpy()
            # Remap the predicted class ids to the dataset labels.
            pred_annual_nn = ids_to_labels(
                validationdataloader, pred_annual).astype(np.uint8)
            for batch in range(feat_daily.shape[0]):
                # NOTE(review): the date list and the georeferencing profile are always
                # taken from dataset sample 0, not from the current patch; presumably
                # only correct when a single patch is processed -- TODO.
                _, tmp_path = get_patch_id(validationdataset.samples, 0)
                patch_dates = get_all_dates(
                    tmp_path, validationdataset.max_seq_length)
                last_tif_path = os.path.join(tmp_path, patch_dates[-1] + ".tif")
                _, profile = read(last_tif_path)
                profile["name"] = dirs_path[batch]
                # Output folder <out_dir>/<pth[1]>-<pth[0]>/<pth[2]> built from the
                # last three components of the patch directory.
                pth = dirs_path[batch].split(os.path.sep)[-3:]
                full_pth_patch = os.path.join(
                    out_dir, pth[1] + '-' + pth[0], pth[2])
                os.makedirs(full_pth_patch, exist_ok=True)
                last_patch_dir = full_pth_patch
                # Every raster written for this patch is single-band uint8 without nodata.
                profile.update({
                    'nodata': None,
                    'dtype': 'uint8',
                    'count': 1})
                full_pth_pred = os.path.join(
                    full_pth_patch, 'patch-pred-nn.tif')
                with rasterio.open(full_pth_pred, 'w', **profile) as dst:
                    dst.write_band(1, pred_annual_nn[batch])
                for ch in range(len(patch_dates)):
                    soft_seg = outs_daily[batch, ch, :, :, :]
                    # transform probs into a hard segmentation
                    pred_daily = torch.argmax(soft_seg, dim=0)
                    pred_daily = pred_daily.cpu()
                    daily_pred = ids_to_labels(
                        validationdataloader, pred_daily).astype(np.uint8)
                    # Save the daily prediction next to the annual one.
                    full_pth_date = os.path.join(
                        full_pth_patch, patch_dates[ch] + '-daily-pred.tif')
                    with rasterio.open(full_pth_date, 'w', **profile) as dst:
                        dst.write_band(1, daily_pred)
    st.markdown('End prediction')
    if last_patch_dir is None:
        # Nothing was predicted: do not leave stale file names for the display step.
        st.paths = None
        st.warning('No patch was found in the selected sample.')
    else:
        # Display what was actually written (works for uploads too, where tileids is None).
        folder_out = last_patch_dir
        st.paths = sorted(os.listdir(folder_out))
# Show the results of the prediction made in THIS run. `st.paths` lives on the
# shared `st` module, so it can be left over from an earlier run or another
# session; requiring `sample_path` guarantees `folder_out` was set in this run.
if sample_path is not None and st.paths is not None:
    st.markdown("\n### Predictions\n")
    # Annual map: read the label raster and colourise it row by row.
    target, _ = read(os.path.join(folder_out, 'patch-pred-nn.tif'))
    target = [classes_color_map[p] for p in np.squeeze(target)]
    fig, ax = plt.subplots()
    ax.imshow(target)
    # One coloured HTML row per class for the legend.
    markdown_legend = ''.join(
        f'<div style="color:gray;background-color: {colors.to_hex(c)};">{label}</div><br>'
        for c, label in zip(color_labels, labels_map))
    col1, col2 = st.columns([2, 1])
    with col1:
        st.markdown("**Long-term (annual) prediction**")
        st.pyplot(fig)
        # Already rendered: close it so pyplot does not keep one figure per rerun alive.
        plt.close(fig)
    with col2:
        st.markdown("**Legend**")
        st.markdown(markdown_legend, unsafe_allow_html=True)
    st.markdown("**Short-term (daily) prediction**")
    # Daily rasters are named "<date>-daily-pred.tif"; the first 8 chars are the date.
    daily_files = [name for name in st.paths if 'daily-pred' in name]
    img_full_paths = [os.path.join(folder_out, name) for name in daily_files]
    pred_dates = [name[:8] for name in daily_files]
    anim = load_daily_preds_as_animations(img_full_paths, pred_dates)
    components.html(anim.to_jshtml(), height=600)
    # The input images are located only for demo tiles (uploads carry no tile id).
    if tileids is not None:
        st.markdown("**Input time series**")
        folder_src = os.path.join("demo_data/lombardia/", opt.years[0], tileids[0])
        tif_files = [f for f in sorted(os.listdir(folder_src)) if f.endswith(".tif")]
        img_full_paths = [os.path.join(folder_src, f) for f in tif_files]
        pred_dates = [f[:8] for f in tif_files]
        anim_src = load_src_images_as_animations(img_full_paths, pred_dates)
        components.html(anim_src.to_jshtml(), height=600)
# Closing section: where to get more patches, plus a placeholder for input-format docs.
_FOOTER_MD = """
## Lombardia Dataset
You can download other patches from the original dataset created and published on
[Kaggle](https://www.kaggle.com/datasets/ignazio/sentinel2-crop-mapping) and used in our paper.
## How to build an input file for the Demo
Working in progress: to be defined ...
"""
st.markdown(_FOOTER_MD)