Spaces:
Sleeping
Sleeping
import streamlit as st | |
import zipfile | |
from utils import * | |
import numpy as np | |
import matplotlib.pyplot as plt | |
import matplotlib.animation as animation | |
import streamlit.components.v1 as components | |
from matplotlib import colors | |
# Page configuration: lay the app out over the full browser width.
st.set_page_config(layout="wide")
def create_animation(images, pred_dates):
    """Build a matplotlib animation that shows one image per frame.

    Args:
        images: non-empty sequence of images accepted by ``imshow``; frame
            ``i`` shows ``images[i]``.
        pred_dates: sequence of dates (one per image); frame ``i`` is labelled
            ``"date: <pred_dates[i]>"``.

    Returns:
        A ``matplotlib.animation.FuncAnimation`` with ``len(images)`` frames,
        played at 2 frames per second and not repeated.

    Raises:
        ValueError: if ``images`` is empty.
    """
    # len() rather than truthiness: callers may pass arrays.
    if len(images) == 0:
        raise ValueError('create_animation needs at least one image')

    print('Creating composition of images...')
    fps = 2

    fig_an, ax_an = plt.subplots()
    ax_an.set_title("")

    # vmin/vmax only affect single-channel frames; RGB frames are shown as-is.
    im = ax_an.imshow(images[0], interpolation='none', aspect='auto', vmin=0, vmax=1)
    title = ax_an.text(0.5, 0.85, "", bbox={'facecolor': 'w', 'alpha': 0.5, 'pad': 5},
                       transform=ax_an.transAxes, ha="center")

    def animate_func(idx):
        title.set_text(f"date: {pred_dates[idx]}")
        im.set_array(images[idx])
        # With blit=True, every artist changed in this frame must be returned.
        # Otherwise the date label is never redrawn.
        return [im, title]

    anima = animation.FuncAnimation(fig_an, animate_func, frames=len(images),
                                    interval=1000 / fps, blit=True, repeat=False)
    print('Done!')
    return anima
def load_daily_preds_as_animations(pred_full_paths, pred_dates):
    """Read daily prediction rasters, colour them and animate them.

    Each raster in ``pred_full_paths`` is read with ``read``. Singleton
    dimensions are squeezed out, and every row of labels is mapped through
    ``classes_color_map``. The resulting frames, together with ``pred_dates``,
    are passed to ``create_animation``, whose animation is returned.
    """
    def _colourise(raster_path):
        labels, _ = read(raster_path)
        return [classes_color_map[row] for row in np.squeeze(labels)]

    frames = [_colourise(raster_path) for raster_path in pred_full_paths]
    return create_animation(frames, pred_dates)
def load_src_images_as_animations(img_paths, pred_dates):
    """Animate the input acquisitions of a patch as true-colour images.

    For every raster in ``img_paths`` the first three bands are taken in the
    order [2, 1, 0] and moved to channels-last. The frame is then scaled by its
    own maximum so values fall in [0, 1]. Frames whose maximum is not positive
    (e.g. all zeros) are shown black instead of producing NaNs from a division
    by zero.

    Returns:
        The animation built by ``create_animation`` from the frames and
        ``pred_dates``.
    """
    imgs = []
    for path in img_paths:
        img, _ = read(path)
        # https://custom-scripts.sentinel-hub.com/custom-scripts/sentinel-2/composites/
        # IREA image band layout:
        # False colors (8,4,3): 2,blue-B3,green-B4,5,6,7,red-B8,11,12
        # Simple RGB (4, 3, 2): blue-B2,green-B3,red-B4,5,6,7,8,11,12
        # With that layout, indices [2, 1, 0] select B4, B3, B2 (red, green, blue).
        rgb = np.moveaxis(img[[2, 1, 0], :, :], 0, -1)
        peak = np.amax(rgb)
        if peak > 0:
            imgs.append(rgb / peak)
        else:
            # Avoid 0/0 -> NaN (and a RuntimeWarning) for empty acquisitions.
            imgs.append(np.zeros(rgb.shape))
    return create_animation(imgs, pred_dates)
# State that must survive Streamlit's top-to-bottom reruns is stored as
# attributes of the streamlit module object.
if not hasattr(st, 'paths'):
    st.paths = None

# Build the daily (FPN) and annual (SimpleNN) networks and load their
# checkpoints only once per process.
if not hasattr(st, 'daily_model'):
    best_model_daily_file_name = "best_model_daily.pth"
    best_model_annual_file_name = "best_model_annual.pth"

    # Zero tensor handed to the FPN constructor; presumably used to size its
    # layers (TODO confirm against FPN).
    first_input_batch = torch.zeros(71, 9, 5, 48, 48)
    st.daily_model = FPN(opt, first_input_batch, opt.win_size)
    st.annual_model = SimpleNN(opt)

    on_gpu = torch.cuda.is_available()
    # NOTE(review): each model ends up wrapped in DataParallel twice, i.e.
    # DataParallel(DataParallel(model)). This is kept as-is because the
    # parameter names `resume` has to match may depend on the wrapping depth.
    # Check how the checkpoints were saved before dropping one level.
    for _ in range(2):
        st.daily_model = torch.nn.DataParallel(st.daily_model)
        st.daily_model = st.daily_model.cuda() if on_gpu else st.daily_model.cpu()
        st.annual_model = torch.nn.DataParallel(st.annual_model)
        st.annual_model = st.annual_model.cuda() if on_gpu else st.annual_model.cpu()

    print('trying to resume previous saved models...')
    state = resume(os.path.join(opt.resume_path, best_model_daily_file_name),
                   model=st.daily_model, optimizer=None)
    state = resume(os.path.join(opt.resume_path, best_model_annual_file_name),
                   model=st.annual_model, optimizer=None)

    # Both networks are only used for inference from here on.
    st.daily_model = st.daily_model.eval()
    st.annual_model = st.annual_model.eval()
# Page header: paper title, BibTeX citation and a warning about the demo state.
header_markdown = """ Demo App for the model presented in the [paper](https://www.sciencedirect.com/science/article/pii/S0924271622003203):
```
@article{gallo2022in_season,
title = {In-season and dynamic crop mapping using 3D convolution neural networks and sentinel-2 time series},
journal = {ISPRS Journal of Photogrammetry and Remote Sensing},
volume = {195},
pages = {335-352},
year = {2023},
issn = {0924-2716},
doi = {https://doi.org/10.1016/j.isprsjprs.2022.12.005},
url = {https://www.sciencedirect.com/science/article/pii/S0924271622003203},
author = {Ignazio Gallo and Luigi Ranghetti and Nicola Landro and Riccardo {La Grassa} and Mirco Boschetti},
}
```
**NOTE: The demo doesn't work properly, we are working to fix the bugs!**
"""
st.title('In-season and dynamic crop mapping using 3D convolution neural networks and sentinel-2 time series')
st.markdown(header_markdown)
# Input selection: either an uploaded zip archive or one of the bundled demo
# tiles. `sample_path` / `tileids` stay None when nothing is selected.
file_uploaded = st.file_uploader(
    "Upload",
    type=["zip"],
    accept_multiple_files=False,
)
sample_path = None
tileids = None
if file_uploaded is not None:
    with zipfile.ZipFile(file_uploaded, "r") as archive:
        archive.extractall("uploaded_samples")
    # Assumes the archive's top-level folder is named after the zip file
    # (its name without the trailing ".zip").
    sample_path = "uploaded_samples/" + file_uploaded.name[:-4]

st.markdown('or use a demo sample')
# Buttons are created in this order; if several report a click, the last
# one wins.
for demo_label, demo_tile in (('sample 1', '24'),
                              ('sample 2', '712'),
                              ('sample 3', '814'),
                              ('sample 4', '1509')):
    if st.button(demo_label):
        sample_path = 'demo_data/lombardia'
        tileids = [demo_tile]
# ---------------------------------------------------------------------------
# Prediction: run both networks on the selected sample and write the label
# rasters under <opt.result_path>/seg_maps. Then remember, for this browser
# session, where the first patch's outputs and inputs are.
#
# Results are kept in st.session_state (per user session) instead of on the
# shared `st` module object. The folders displayed below are the ones actually
# written/read here, not paths rebuilt from `tileids`: `tileids` is None for
# uploaded archives and on every rerun without a button click, which used to
# crash the display with a TypeError.
# ---------------------------------------------------------------------------
_RESULT_KEY = 'prediction_result'

if sample_path is not None:
    validationdataset = SentinelDailyAnnualDatasetNoLabel(
        sample_path,
        opt.years,
        opt.classes_path,
        opt.sample_duration,
        opt.win_size,
        tileids=tileids)
    validationdataloader = torch.utils.data.DataLoader(
        validationdataset, batch_size=opt.batch_size, shuffle=False, num_workers=opt.workers)

    st.markdown('Model prediction in progress ...')
    out_dir = os.path.join(opt.result_path, "seg_maps")
    os.makedirs(out_dir, exist_ok=True)

    first_patch_out = None  # output folder of the first patch written
    first_patch_src = None  # input folder that patch was read from

    for x_dailies, _, dirs_path in validationdataloader:
        # The last batch can be smaller than opt.batch_size, so restore the
        # layout from the tensor's own leading dims instead of from opt.
        n_batch, n_steps = x_dailies.shape[:2]
        with torch.no_grad():
            # Merge the first two dims so the daily model sees one row per
            # (sample, step) pair.
            x_dailies = x_dailies.view(-1, *x_dailies.shape[2:])
            if torch.cuda.is_available():
                x_dailies = x_dailies.cuda()
            feat_daily, outs_daily = st.daily_model.forward(x_dailies)
            # Back to (batch, steps, ...).
            outs_daily = outs_daily.view(n_batch, n_steps, *outs_daily.shape[1:])
            feat_daily = feat_daily.view(n_batch, n_steps, *feat_daily.shape[1:])
            _, out_annual = st.annual_model.forward(feat_daily)
            pred_annual = torch.argmax(out_annual, dim=1).squeeze(1).cpu().numpy()
        # Remap model class ids to dataset labels.
        pred_annual_nn = ids_to_labels(validationdataloader, pred_annual).astype(np.uint8)

        for batch in range(n_batch):
            # NOTE(review): the raster profile and the date list always come
            # from sample 0 of the dataset, not from dirs_path[batch]. This is
            # only correct when the sample contains a single patch.
            _, tmp_path = get_patch_id(validationdataset.samples, 0)
            patch_dates = get_all_dates(tmp_path, validationdataset.max_seq_length)
            _, profile = read(os.path.join(tmp_path, patch_dates[-1] + ".tif"))
            profile["name"] = dirs_path[batch]
            # Every output is a single-band uint8 label raster without nodata.
            profile.update({
                'nodata': None,
                'dtype': 'uint8',
                'count': 1})

            # The last three components of the input folder (presumably
            # <region>/<year>/<tile>) become <out_dir>/<year>-<region>/<tile>.
            pth = dirs_path[batch].split(os.path.sep)[-3:]
            full_pth_patch = os.path.join(out_dir, pth[1] + '-' + pth[0], pth[2])
            os.makedirs(full_pth_patch, exist_ok=True)
            if first_patch_out is None:
                first_patch_out = full_pth_patch
                first_patch_src = dirs_path[batch]

            with rasterio.open(os.path.join(full_pth_patch, 'patch-pred-nn.tif'), 'w', **profile) as dst:
                dst.write_band(1, pred_annual_nn[batch])

            for ch, date in enumerate(patch_dates):
                # Hard segmentation: most probable class per pixel on this date.
                # Converted to numpy, as for the annual prediction above.
                pred_daily = torch.argmax(outs_daily[batch, ch], dim=0).cpu().numpy()
                daily_pred = ids_to_labels(validationdataloader, pred_daily).astype(np.uint8)
                full_pth_date = os.path.join(full_pth_patch, date + '-daily-pred.tif')
                with rasterio.open(full_pth_date, 'w', **profile) as dst:
                    dst.write_band(1, daily_pred)
    st.markdown('End prediction')

    if first_patch_out is None:
        # Nothing was predicted; do not keep showing an older result.
        if _RESULT_KEY in st.session_state:
            del st.session_state[_RESULT_KEY]
        st.markdown('No patch found in the selected sample.')
    else:
        st.session_state[_RESULT_KEY] = (
            first_patch_out, first_patch_src, sorted(os.listdir(first_patch_out)))

# ---------------------------------------------------------------------------
# Display the last prediction made in this session (also on later reruns).
# ---------------------------------------------------------------------------
prediction_result = st.session_state.get(_RESULT_KEY)
if prediction_result is not None:
    folder_out, folder_src, prediction_files = prediction_result
    st.markdown("""
### Predictions
""")
    # Annual map: colour each row of labels through classes_color_map.
    target, profile = read(os.path.join(folder_out, 'patch-pred-nn.tif'))
    target = [classes_color_map[p] for p in np.squeeze(target)]
    fig, ax = plt.subplots()
    ax.imshow(target)

    markdown_legend = ''.join(
        f'<div style="color:gray;background-color: {colors.to_hex(c)};">{l}</div><br>'
        for c, l in zip(color_labels, labels_map))

    col1, col2 = st.columns([2, 1])
    with col1:
        st.markdown("**Long-term (annual) prediction**")
        st.pyplot(fig)
    with col2:
        st.markdown("**Legend**")
        st.markdown(markdown_legend, unsafe_allow_html=True)
    # st.pyplot has already rendered the figure; close it so reruns do not
    # accumulate open matplotlib figures.
    plt.close(fig)

    st.markdown("**Short-term (daily) prediction**")
    daily_files = [f for f in prediction_files if 'daily-pred' in f]
    if daily_files:
        # The first 8 characters of each name are the date prefix
        # (presumably YYYYMMDD).
        anim = load_daily_preds_as_animations(
            [os.path.join(folder_out, f) for f in daily_files],
            [f[:8] for f in daily_files])
        components.html(anim.to_jshtml(), height=600)

    st.markdown("**Input time series**")
    src_files = sorted(f for f in os.listdir(folder_src) if f.endswith(".tif"))
    if src_files:
        anim_src = load_src_images_as_animations(
            [os.path.join(folder_src, f) for f in src_files],
            [f[:8] for f in src_files])
        components.html(anim_src.to_jshtml(), height=600)
# Footer: where to get more patches, and the (not yet written) input-format guide.
st.markdown("""
## Lombardia Dataset
You can download other patches from the original dataset created and published on
[Kaggle](https://www.kaggle.com/datasets/ignazio/sentinel2-crop-mapping) and used in our paper.
## How to build an input file for the Demo
Work in progress: to be defined ...
""")