Upload folder using huggingface_hub
This view is limited to 50 files because it contains too many changes.
- UnCRtainTS/.gitignore +6 -0
- UnCRtainTS/Dockerfile +35 -0
- UnCRtainTS/README.md +109 -0
- UnCRtainTS/architecture.png +0 -0
- UnCRtainTS/data/__init__.py +0 -0
- UnCRtainTS/data/dataLoader.py +633 -0
- UnCRtainTS/environment.yaml +13 -0
- UnCRtainTS/model/.gitignore +136 -0
- UnCRtainTS/model/checkpoint/diffcr_bs32_epoch17/model.pth.tar +3 -0
- UnCRtainTS/model/checkpoint/monotemporalL2/conf.json +56 -0
- UnCRtainTS/model/checkpoint/monotemporalL2/model.pth.tar +3 -0
- UnCRtainTS/model/ensemble_reconstruct.py +180 -0
- UnCRtainTS/model/inference/diffcr_bs32_epoch17/conf.json +73 -0
- UnCRtainTS/model/inference/monotemporalL2/conf.json +75 -0
- UnCRtainTS/model/inference/monotemporalL2/test_metrics.json +11 -0
- UnCRtainTS/model/parse_args.py +95 -0
- UnCRtainTS/model/results/diffcr_bs32_lr15e-4/conf.json +74 -0
- UnCRtainTS/model/results/diffcr_bs32_lr15e-4/model.pth.tar +3 -0
- UnCRtainTS/model/results/diffcr_bs32_lr15e-4/model_epoch_11.pth.tar +3 -0
- UnCRtainTS/model/results/diffcr_bs32_lr15e-4/model_epoch_36.pth.tar +3 -0
- UnCRtainTS/model/results/diffcr_bs32_lr15e-4/test_epoch_10_metrics.json +11 -0
- UnCRtainTS/model/results/diffcr_bs32_lr15e-4/test_epoch_11_metrics.json +11 -0
- UnCRtainTS/model/results/diffcr_bs32_lr15e-4/test_epoch_12_metrics.json +11 -0
- UnCRtainTS/model/results/diffcr_bs32_lr15e-4/test_epoch_13_metrics.json +11 -0
- UnCRtainTS/model/results/diffcr_bs32_lr15e-4/test_epoch_14_metrics.json +11 -0
- UnCRtainTS/model/results/diffcr_bs32_lr15e-4/test_epoch_15_metrics.json +11 -0
- UnCRtainTS/model/results/diffcr_bs32_lr15e-4/test_epoch_16_metrics.json +11 -0
- UnCRtainTS/model/results/diffcr_bs32_lr15e-4/test_epoch_17_metrics.json +11 -0
- UnCRtainTS/model/results/diffcr_bs32_lr15e-4/test_epoch_18_metrics.json +11 -0
- UnCRtainTS/model/results/diffcr_bs32_lr15e-4/test_epoch_19_metrics.json +11 -0
- UnCRtainTS/model/results/diffcr_bs32_lr15e-4/test_epoch_1_metrics.json +11 -0
- UnCRtainTS/model/results/diffcr_bs32_lr15e-4/test_epoch_20_metrics.json +11 -0
- UnCRtainTS/model/results/diffcr_bs32_lr15e-4/test_epoch_21_metrics.json +11 -0
- UnCRtainTS/model/results/diffcr_bs32_lr15e-4/test_epoch_22_metrics.json +11 -0
- UnCRtainTS/model/results/diffcr_bs32_lr15e-4/test_epoch_23_metrics.json +11 -0
- UnCRtainTS/model/results/diffcr_bs32_lr15e-4/test_epoch_24_metrics.json +11 -0
- UnCRtainTS/model/results/diffcr_bs32_lr15e-4/test_epoch_25_metrics.json +11 -0
- UnCRtainTS/model/results/diffcr_bs32_lr15e-4/test_epoch_26_metrics.json +11 -0
- UnCRtainTS/model/results/diffcr_bs32_lr15e-4/test_epoch_27_metrics.json +11 -0
- UnCRtainTS/model/results/diffcr_bs32_lr15e-4/test_epoch_28_metrics.json +11 -0
- UnCRtainTS/model/results/diffcr_bs32_lr15e-4/test_epoch_29_metrics.json +11 -0
- UnCRtainTS/model/results/diffcr_bs32_lr15e-4/test_epoch_2_metrics.json +11 -0
- UnCRtainTS/model/results/diffcr_bs32_lr15e-4/test_epoch_30_metrics.json +11 -0
- UnCRtainTS/model/results/diffcr_bs32_lr15e-4/test_epoch_31_metrics.json +11 -0
- UnCRtainTS/model/results/diffcr_bs32_lr15e-4/test_epoch_32_metrics.json +11 -0
- UnCRtainTS/model/results/diffcr_bs32_lr15e-4/test_epoch_33_metrics.json +11 -0
- UnCRtainTS/model/results/diffcr_bs32_lr15e-4/test_epoch_34_metrics.json +11 -0
- UnCRtainTS/model/results/diffcr_bs32_lr15e-4/test_epoch_35_metrics.json +11 -0
- UnCRtainTS/model/results/diffcr_bs32_lr15e-4/test_epoch_36_metrics.json +11 -0
- UnCRtainTS/model/results/diffcr_bs32_lr15e-4/test_epoch_3_metrics.json +11 -0
UnCRtainTS/.gitignore
ADDED
@@ -0,0 +1,6 @@
+*.npy
+logs
+model/inference
+model/checkpoint
+model/results
+*_pycache_*
UnCRtainTS/Dockerfile
ADDED
@@ -0,0 +1,35 @@
+FROM pytorch/pytorch:2.0.0-cuda11.7-cudnn8-runtime
+
+# install dependencies
+#RUN pip install functorch
+# '' may actually no longer be needed from torch 1.13 on
+RUN pip install cupy-cuda112
+RUN conda install -c conda-forge cupy
+#RUN conda install pytorch torchvision cudatoolkit=11.7 -c pytorch
+RUN pip install opencv-python
+# RUN conda install -c conda-forge opencv
+RUN pip install scipy rasterio natsort matplotlib scikit-image tqdm pandas
+RUN pip install Pillow dominate visdom tensorboard
+RUN pip install kornia torchgeometry torchmetrics torchnet segmentation-models-pytorch
+RUN pip install s2cloudless
+# see: https://github.com/sentinel-hub/sentinel2-cloud-detector/issues/17
+RUN pip install numpy==1.21.6
+
+RUN apt-get -y update
+RUN apt-get -y install git
+RUN pip install -U 'git+https://github.com/facebookresearch/fvcore'
+
+# just in case some last-minute changes are needed
+RUN apt-get install nano
+
+# bake repository into dockerfile
+RUN mkdir -p ./data
+RUN mkdir -p ./model
+RUN mkdir -p ./util
+
+ADD data ./data
+ADD model ./model
+ADD util ./util
+ADD . ./
+
+WORKDIR /workspace/model
UnCRtainTS/README.md
ADDED
@@ -0,0 +1,109 @@
+# UnCRtainTS: Uncertainty Quantification for Cloud Removal in Optical Satellite Time Series
+
+
+>
+> _This is the official repository for UnCRtainTS, a network for multi-temporal cloud removal in satellite data combining a novel attention-based architecture and a formulation for multivariate uncertainty prediction. These two components combined set a new state-of-the-art performance in terms of image reconstruction on two public cloud removal datasets. Additionally, we show how the well-calibrated predicted uncertainties enable a precise control of the reconstruction quality._
+----
+This repository contains code accompanying the paper
+> P. Ebel, V. Garnot, M. Schmitt, J. Wegner and X. X. Zhu. UnCRtainTS: Uncertainty Quantification for Cloud Removal in Optical Satellite Time Series. Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition Workshops, 2023.
+
+For additional information:
+
+* The publication is available in the [CVPRW Proceedings](https://openaccess.thecvf.com/content/CVPR2023W/EarthVision/papers/Ebel_UnCRtainTS_Uncertainty_Quantification_for_Cloud_Removal_in_Optical_Satellite_Time_CVPRW_2023_paper.pdf).
+* The SEN12MS-CR-TS data set is accessible at the MediaTUM page [here](https://mediatum.ub.tum.de/1639953) (train split) and [here](https://mediatum.ub.tum.de/1659251) (test split).
+* You can find additional information on this and related projects on the associated [cloud removal projects page](https://patrickTUM.github.io/cloud_removal/).
+* For any further questions, please reach out to me here or via the credentials on my [website](https://pwjebel.com).
+---
+
+## Installation
+### Dataset
+
+You can easily download the multi-temporal SEN12MS-CR-TS (and, optionally, the mono-temporal SEN12MS-CR) dataset via the shell script in [`./util/dl_data.sh`](https://github.com/PatrickTUM/UnCRtainTS/blob/main/util/dl_data.sh). Alternatively, you may download the SEN12MS-CR-TS data set (or parts of it) via the MediaTUM website [here](https://mediatum.ub.tum.de/1639953) (train split) and [here](https://mediatum.ub.tum.de/1659251) (test split), with further instructions provided in the dataset's own [dedicated repository](https://github.com/PatrickTUM/SEN12MS-CR-TS#dataset).
+
+### Code
+Clone this repository via `git clone https://github.com/PatrickTUM/UnCRtainTS.git`
+
+and set up the Python environment via
+
+```bash
+conda env create --file environment.yaml
+conda activate uncrtaints
+```
+
+Alternatively, you may install all that's needed via
+```bash
+pip install -r requirements.txt
+```
+or by building a Docker image from the provided `Dockerfile` and deploying a container.
+
+The code is written in Python 3 and uses PyTorch $\geq$ 2.0. It is strongly recommended to run the code with CUDA and GPU support. The code has been developed and deployed on Ubuntu 20.04 LTS and should be able to run on any comparable OS.
+
+---
+
+## Usage
+### Dataset
+If you already have your own model in place or wish to build one on the SEN12MS-CR-TS data loader for training and testing, the data loader can be used as a stand-alone script as demonstrated in `./standalone_dataloader.py`. This only requires the files `./data/dataLoader.py` (the actual data loader) and `./util/detect_cloudshadow.py` (if this type of cloud detector is chosen).
+
+To use the dataset stand-alone with your own model, loading multi-temporal multi-modal data from SEN12MS-CR-TS is as simple as
+
+```python
+import torch
+from data.dataLoader import SEN12MSCRTS
+dir_SEN12MSCRTS = '/path/to/your/SEN12MSCRTS'
+sen12mscrts = SEN12MSCRTS(dir_SEN12MSCRTS, split='all', region='all', n_input_samples=3)
+dataloader = torch.utils.data.DataLoader(sen12mscrts)
+
+for pdx, samples in enumerate(dataloader): print(samples['input'].keys())
+```
+
+and, likewise, if you wish to (pre-)train on the mono-temporal multi-modal SEN12MS-CR dataset:
+
+```python
+import torch
+from data.dataLoader import SEN12MSCR
+dir_SEN12MSCR = '/path/to/your/SEN12MSCR'
+sen12mscr = SEN12MSCR(dir_SEN12MSCR, split='all', region='all')
+dataloader = torch.utils.data.DataLoader(sen12mscr)
+
+for pdx, samples in enumerate(dataloader): print(samples['input'].keys())
+```
+
+Note that computing cloud masks on the fly, depending on the choice of cloud detection, may slow down data loading. For greater efficiency, files of pre-computed cloud coverage statistics can be
+downloaded [here](https://u.pcloud.link/publink/show?code=kZXdbk0ZaAHNV2a5ofbB9UW4xCyCT0YFYAFk) or pre-computed via `./util/pre_compute_data_samples.py`, and then loaded with the `--precomputed /path/to/files/` flag.
+
+### Basic Commands
+You can train a new model via
+```bash
+cd ./UnCRtainTS/model
+python train_reconstruct.py --experiment_name my_first_experiment --root1 path/to/SEN12MSCRtrain --root2 path/to/SEN12MSCRtest --root3 path/to/SEN12MSCR --model uncrtaints --input_t 3 --region all --epochs 20 --lr 0.001 --batch_size 4 --gamma 1.0 --scale_by 10.0 --trained_checkp "" --loss MGNLL --covmode diag --var_nonLinearity softplus --display_step 10 --use_sar --block_type mbconv --n_head 16 --device cuda --res_dir ./results --rdm_seed 1
+```
+and you can test a (pre-)trained model via
+```bash
+python test_reconstruct.py --experiment_name my_first_experiment --root1 path/to/SEN12MSCRtrain --root2 path/to/SEN12MSCRtest --root3 path/to/SEN12MSCR --input_t 3 --region all --export_every 1 --res_dir ./inference --weight_folder ./results
+```
+
+For a list and description of all flags, please see the parser file `./model/parse_args.py`. To perform inference with pre-trained models, you can find the checkpoints [here](https://u.pcloud.link/publink/show?code=kZsdbk0Z5Y2Y2UEm48XLwOvwSVlL8R2L3daV). Every checkpoint is accompanied by a json file documenting the flags set during training, which are expected to reproduce the model's behavior at test time. If you point the test script towards the exported configuration when calling it, the correct settings are loaded automatically. Finally, after exporting model predictions via `test_reconstruct.py`, multiple models' outputs can be ensembled via `ensemble_reconstruct.py` to obtain estimates of epistemic uncertainty.
+
+---
+
+
+## References
+
+If you use this code, our models or data set for your research, please cite [this](https://openaccess.thecvf.com/content/CVPR2023W/EarthVision/papers/Ebel_UnCRtainTS_Uncertainty_Quantification_for_Cloud_Removal_in_Optical_Satellite_Time_CVPRW_2023_paper.pdf) publication:
+```bibtex
+@inproceedings{UnCRtainTS,
+  title = {{UnCRtainTS: Uncertainty Quantification for Cloud Removal in Optical Satellite Time Series}},
+  author = {Ebel, Patrick and Garnot, Vivien Sainte Fare and Schmitt, Michael and Wegner, Jan and Zhu, Xiao Xiang},
+  booktitle = {Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition Workshops},
+  year = {2023},
+  organization = {IEEE},
+  url = {https://openaccess.thecvf.com/content/CVPR2023W/EarthVision/papers/Ebel_UnCRtainTS_Uncertainty_Quantification_for_Cloud_Removal_in_Optical_Satellite_Time_CVPRW_2023_paper.pdf}
+}
+```
+You may also be interested in our related works, which you can discover on the accompanying [cloud removal projects website](https://patrickTUM.github.io/cloud_removal/).
+
+
+
+## Credits
+
+This code was originally based on the [UTAE](https://github.com/VSainteuf/utae-paps) and the [SEN12MS-CR-TS](https://github.com/PatrickTUM/SEN12MS-CR-TS) repositories. Thanks for making your code publicly available! We hope this repository will equally contribute to the development of future exciting work.
UnCRtainTS/architecture.png
ADDED
(binary image file, not shown)
UnCRtainTS/data/__init__.py
ADDED
File without changes
UnCRtainTS/data/dataLoader.py
ADDED
@@ -0,0 +1,633 @@
+import os
+import glob
+import warnings
+import numpy as np
+from natsort import natsorted
+
+from datetime import datetime
+
+to_date = lambda string: datetime.strptime(string, "%Y-%m-%d")
+S1_LAUNCH = to_date("2014-04-03")
+
+# s2cloudless: see https://github.com/sentinel-hub/sentinel2-cloud-detector
+from s2cloudless import S2PixelCloudDetector
+
+import rasterio
+from rasterio.merge import merge
+from scipy.ndimage import gaussian_filter
+from torch.utils.data import Dataset
+# import sys
+
+# sys.path.append(".")
+from util.detect_cloudshadow import get_cloud_mask, get_shadow_mask
+
+
+# utility functions used in the dataloaders of SEN12MS-CR and SEN12MS-CR-TS
+def read_tif(path_IMG):
+    tif = rasterio.open(path_IMG)
+    return tif
+
+
+def read_img(tif):
+    return tif.read().astype(np.float32)
+
+
+def rescale(img, oldMin, oldMax):
+    oldRange = oldMax - oldMin
+    img = (img - oldMin) / oldRange
+    return img
+
+
+def process_MS(img, method):
+    if method == "default":
+        intensity_min, intensity_max = (
+            0,
+            10000,
+        )  # define a reasonable range of MS intensities
+        img = np.clip(
+            img, intensity_min, intensity_max
+        )  # intensity clipping to a global unified MS intensity range
+        img = rescale(
+            img, intensity_min, intensity_max
+        )  # project to [0,1], preserve global intensities (across patches), gets mapped to [-1,+1] in wrapper
+    if method == "resnet":
+        intensity_min, intensity_max = (
+            0,
+            10000,
+        )  # define a reasonable range of MS intensities
+        img = np.clip(
+            img, intensity_min, intensity_max
+        )  # intensity clipping to a global unified MS intensity range
+        img /= 2000  # project to [0,5], preserve global intensities (across patches)
+    img = np.nan_to_num(img)
+    return img
+
+
+def process_SAR(img, method):
+    if method == "default":
+        dB_min, dB_max = -25, 0  # define a reasonable range of SAR dB
+        img = np.clip(
+            img, dB_min, dB_max
+        )  # intensity clipping to a global unified SAR dB range
+        img = rescale(
+            img, dB_min, dB_max
+        )  # project to [0,1], preserve global intensities (across patches), gets mapped to [-1,+1] in wrapper
+    if method == "resnet":
+        # project SAR to [0, 2] range
+        dB_min, dB_max = [-25.0, -32.5], [0, 0]
+        img = np.concatenate(
+            [
+                (
+                    2
+                    * (np.clip(img[0], dB_min[0], dB_max[0]) - dB_min[0])
+                    / (dB_max[0] - dB_min[0])
+                )[None, ...],
+                (
+                    2
+                    * (np.clip(img[1], dB_min[1], dB_max[1]) - dB_min[1])
+                    / (dB_max[1] - dB_min[1])
+                )[None, ...],
+            ],
+            axis=0,
+        )
+    img = np.nan_to_num(img)
+    return img
+
+
+def get_cloud_cloudshadow_mask(img, cloud_threshold=0.2):
+    cloud_mask = get_cloud_mask(img, cloud_threshold, binarize=True)
+    shadow_mask = get_shadow_mask(img)
+
+    # encode clouds and shadows as segmentation masks
+    cloud_cloudshadow_mask = np.zeros_like(cloud_mask)
+    cloud_cloudshadow_mask[shadow_mask < 0] = -1
+    cloud_cloudshadow_mask[cloud_mask > 0] = 1
+
+    # label clouds and shadows
+    cloud_cloudshadow_mask[cloud_cloudshadow_mask != 0] = 1
+    return cloud_cloudshadow_mask
+
+
+# recursively apply function to nested dictionary
+def iterdict(dictionary, fct):
+    for k, v in dictionary.items():
+        if isinstance(v, dict):
+            dictionary[k] = iterdict(v, fct)
+        else:
+            dictionary[k] = fct(v)
+    return dictionary
+
+
+def get_cloud_map(img, detector, instance=None):
+    # get cloud masks
+    img = np.clip(img, 0, 10000)
+    mask = np.ones((img.shape[-1], img.shape[-1]))
+    # note: if your model may suffer from dark pixel artifacts,
+    # you may consider adjusting these filtering parameters
+    if not (img.mean() < 1e-5 and img.std() < 1e-5):
+        if detector == "cloud_cloudshadow_mask":
+            threshold = 0.2  # set to e.g. 0.2 or 0.4
+            mask = get_cloud_cloudshadow_mask(img, threshold)
+        elif detector == "s2cloudless_map":
+            threshold = 0.5
+            mask = instance.get_cloud_probability_maps(
+                np.moveaxis(img / 10000, 0, -1)[None, ...]
+            )[0, ...]
+            mask[mask < threshold] = 0
+            mask = gaussian_filter(mask, sigma=2)
+        elif detector == "s2cloudless_mask":
+            mask = instance.get_cloud_masks(np.moveaxis(img / 10000, 0, -1)[None, ...])[
+                0, ...
+            ]
+        else:
+            mask = np.ones((img.shape[-1], img.shape[-1]))
+            warnings.warn(f"Method {detector} not yet implemented!")
+    else:
+        warnings.warn(f"Encountered a blank sample, defaulting to cloudy mask.")
+    return mask.astype(np.float32)
+
+
+# function to fetch paired data, which may differ in modalities or dates
+def get_pairedS1(patch_list, root_dir, mod=None, time=None):
+    paired_list = []
+    for patch in patch_list:
+        seed, roi, modality, time_number, fname = patch.split("/")
+        time = time_number if time is None else time  # unless overwriting, ...
+        mod = (
+            modality if mod is None else mod
+        )  # keep the patch list's original time and modality
+        n_patch = fname.split("patch_")[-1].split(".tif")[0]
+        paired_dir = os.path.join(seed, roi, mod.upper(), str(time))
+        candidates = os.path.join(
+            root_dir,
+            paired_dir,
+            f"{mod}_{seed}_{roi}_ImgNo_{time}_*_patch_{n_patch}.tif",
+        )
+        paired_list.append(
+            os.path.join(paired_dir, os.path.basename(glob.glob(candidates)[0]))
+        )
+    return paired_list
+
+
+
+
+
+
+""" SEN12MSCR data loader class, inherits from torch.utils.data.Dataset
+
+    IN:
+    root: str, path to your copy of the SEN12MS-CR-TS data set
+    split: str, in [all | train | val | test]
+    region: str, [all | africa | america | asiaEast | asiaWest | europa]
+    cloud_masks: str, type of cloud mask detector to run on optical data, in []
+    sample_type: str, [generic | cloudy_cloudfree]
+    n_input_samples: int, number of input samples in time series
+    rescale_method: str, [default | resnet]
+
+    OUT:
+    data_loader: SEN12MSCRTS instance, implements an iterator that can be traversed via __getitem__(pdx),
+    which returns the pdx-th dictionary of patch-samples (whose structure depends on sample_type)
+"""
+
+
+class SEN12MSCR(Dataset):
+    def __init__(
+        self,
+        root,
+        split="all",
+        region="all",
+        cloud_masks="s2cloudless_mask",
+        sample_type="pretrain",
+        rescale_method="default",
+    ):
+        self.root_dir = root  # set root directory which contains all ROI
+        self.region = region  # region according to which the ROI are selected
+        if self.region != "all":
+            raise NotImplementedError  # TODO: currently only supporting 'all'
+        self.ROI = {
+            "ROIs1158": ["106"],
+            "ROIs1868": [
+                "17",
+                "36",
+                "56",
+                "73",
+                "85",
+                "100",
+                "114",
+                "119",
+                "121",
+                "126",
+                "127",
+                "139",
+                "142",
+                "143",
+            ],
+            "ROIs1970": [
+                "20",
+                "21",
+                "35",
+                "40",
+                "57",
+                "65",
+                "71",
+                "82",
+                "83",
+                "91",
+                "112",
+                "116",
+                "119",
+                "128",
+                "132",
+                "133",
+                "135",
+                "139",
+                "142",
+                "144",
+                "149",
+            ],
+            "ROIs2017": [
+                "8",
+                "22",
+                "25",
+                "32",
+                "49",
+                "61",
+                "63",
+                "69",
+                "75",
+                "103",
+                "108",
+                "115",
+                "116",
+                "117",
+                "130",
+                "140",
+                "146",
+            ],
+        }
+
+        # define splits conform with SEN12MS-CR-TS
+        self.splits = {}
+        self.splits["train"] = [
+            "ROIs1970_fall_s1/s1_3",
+            "ROIs1970_fall_s1/s1_22",
+            "ROIs1970_fall_s1/s1_148",
+            "ROIs1970_fall_s1/s1_107",
+            "ROIs1970_fall_s1/s1_1",
+            "ROIs1970_fall_s1/s1_114",
+            "ROIs1970_fall_s1/s1_135",
+            "ROIs1970_fall_s1/s1_40",
+            "ROIs1970_fall_s1/s1_42",
+            "ROIs1970_fall_s1/s1_31",
+            "ROIs1970_fall_s1/s1_149",
+            "ROIs1970_fall_s1/s1_64",
+            "ROIs1970_fall_s1/s1_28",
+            "ROIs1970_fall_s1/s1_144",
+            "ROIs1970_fall_s1/s1_57",
+            "ROIs1970_fall_s1/s1_35",
+            "ROIs1970_fall_s1/s1_133",
+            "ROIs1970_fall_s1/s1_30",
+            "ROIs1970_fall_s1/s1_134",
+            "ROIs1970_fall_s1/s1_141",
+            "ROIs1970_fall_s1/s1_112",
+            "ROIs1970_fall_s1/s1_116",
+            "ROIs1970_fall_s1/s1_37",
+            "ROIs1970_fall_s1/s1_26",
+            "ROIs1970_fall_s1/s1_77",
+            "ROIs1970_fall_s1/s1_100",
+            "ROIs1970_fall_s1/s1_83",
+            "ROIs1970_fall_s1/s1_71",
+            "ROIs1970_fall_s1/s1_93",
+            "ROIs1970_fall_s1/s1_119",
+            "ROIs1970_fall_s1/s1_104",
+            "ROIs1970_fall_s1/s1_136",
+            "ROIs1970_fall_s1/s1_6",
+            "ROIs1970_fall_s1/s1_41",
+            "ROIs1970_fall_s1/s1_125",
+            "ROIs1970_fall_s1/s1_91",
+            "ROIs1970_fall_s1/s1_131",
+            "ROIs1970_fall_s1/s1_120",
+            "ROIs1970_fall_s1/s1_110",
+            "ROIs1970_fall_s1/s1_19",
+            "ROIs1970_fall_s1/s1_14",
+            "ROIs1970_fall_s1/s1_81",
+            "ROIs1970_fall_s1/s1_39",
+            "ROIs1970_fall_s1/s1_109",
+            "ROIs1970_fall_s1/s1_33",
+            "ROIs1970_fall_s1/s1_88",
+            "ROIs1970_fall_s1/s1_11",
+            "ROIs1970_fall_s1/s1_128",
+            "ROIs1970_fall_s1/s1_142",
+            "ROIs1970_fall_s1/s1_122",
+            "ROIs1970_fall_s1/s1_4",
+            "ROIs1970_fall_s1/s1_27",
+            "ROIs1970_fall_s1/s1_147",
+            "ROIs1970_fall_s1/s1_85",
+            "ROIs1970_fall_s1/s1_82",
+            "ROIs1970_fall_s1/s1_105",
+            "ROIs1158_spring_s1/s1_9",
+            "ROIs1158_spring_s1/s1_1",
+            "ROIs1158_spring_s1/s1_124",
+            "ROIs1158_spring_s1/s1_40",
+            "ROIs1158_spring_s1/s1_101",
+            "ROIs1158_spring_s1/s1_21",
+            "ROIs1158_spring_s1/s1_134",
+            "ROIs1158_spring_s1/s1_145",
+            "ROIs1158_spring_s1/s1_141",
+            "ROIs1158_spring_s1/s1_66",
+            "ROIs1158_spring_s1/s1_8",
+            "ROIs1158_spring_s1/s1_26",
+            "ROIs1158_spring_s1/s1_77",
+            "ROIs1158_spring_s1/s1_113",
+            "ROIs1158_spring_s1/s1_100",
+            "ROIs1158_spring_s1/s1_117",
+            "ROIs1158_spring_s1/s1_119",
+            "ROIs1158_spring_s1/s1_6",
+            "ROIs1158_spring_s1/s1_58",
+            "ROIs1158_spring_s1/s1_120",
+            "ROIs1158_spring_s1/s1_110",
+            "ROIs1158_spring_s1/s1_126",
+            "ROIs1158_spring_s1/s1_115",
+            "ROIs1158_spring_s1/s1_121",
+            "ROIs1158_spring_s1/s1_39",
+            "ROIs1158_spring_s1/s1_109",
+            "ROIs1158_spring_s1/s1_63",
+            "ROIs1158_spring_s1/s1_75",
+            "ROIs1158_spring_s1/s1_132",
+            "ROIs1158_spring_s1/s1_128",
+            "ROIs1158_spring_s1/s1_142",
+            "ROIs1158_spring_s1/s1_15",
+            "ROIs1158_spring_s1/s1_45",
+            "ROIs1158_spring_s1/s1_97",
+            "ROIs1158_spring_s1/s1_147",
+            "ROIs1868_summer_s1/s1_90",
+            "ROIs1868_summer_s1/s1_87",
+            "ROIs1868_summer_s1/s1_25",
+            "ROIs1868_summer_s1/s1_124",
+            "ROIs1868_summer_s1/s1_114",
+            "ROIs1868_summer_s1/s1_135",
+            "ROIs1868_summer_s1/s1_40",
+            "ROIs1868_summer_s1/s1_101",
+            "ROIs1868_summer_s1/s1_42",
+            "ROIs1868_summer_s1/s1_31",
+            "ROIs1868_summer_s1/s1_36",
+            "ROIs1868_summer_s1/s1_139",
+            "ROIs1868_summer_s1/s1_56",
+            "ROIs1868_summer_s1/s1_133",
+            "ROIs1868_summer_s1/s1_55",
+            "ROIs1868_summer_s1/s1_43",
+            "ROIs1868_summer_s1/s1_113",
+            "ROIs1868_summer_s1/s1_76",
+            "ROIs1868_summer_s1/s1_123",
+            "ROIs1868_summer_s1/s1_143",
+            "ROIs1868_summer_s1/s1_93",
+            "ROIs1868_summer_s1/s1_125",
+            "ROIs1868_summer_s1/s1_89",
+            "ROIs1868_summer_s1/s1_120",
+            "ROIs1868_summer_s1/s1_126",
+            "ROIs1868_summer_s1/s1_72",
+            "ROIs1868_summer_s1/s1_115",
+            "ROIs1868_summer_s1/s1_121",
+            "ROIs1868_summer_s1/s1_146",
+            "ROIs1868_summer_s1/s1_140",
+            "ROIs1868_summer_s1/s1_95",
+            "ROIs1868_summer_s1/s1_102",
+            "ROIs1868_summer_s1/s1_7",
+            "ROIs1868_summer_s1/s1_11",
+            "ROIs1868_summer_s1/s1_132",
+            "ROIs1868_summer_s1/s1_15",
+            "ROIs1868_summer_s1/s1_137",
+            "ROIs1868_summer_s1/s1_4",
+            "ROIs1868_summer_s1/s1_27",
+            "ROIs1868_summer_s1/s1_147",
+            "ROIs1868_summer_s1/s1_86",
+            "ROIs1868_summer_s1/s1_47",
+            "ROIs2017_winter_s1/s1_68",
+            "ROIs2017_winter_s1/s1_25",
+            "ROIs2017_winter_s1/s1_62",
+            "ROIs2017_winter_s1/s1_135",
+            "ROIs2017_winter_s1/s1_42",
+            "ROIs2017_winter_s1/s1_64",
+            "ROIs2017_winter_s1/s1_21",
+            "ROIs2017_winter_s1/s1_55",
+            "ROIs2017_winter_s1/s1_112",
+            "ROIs2017_winter_s1/s1_116",
+            "ROIs2017_winter_s1/s1_8",
+            "ROIs2017_winter_s1/s1_59",
+            "ROIs2017_winter_s1/s1_49",
+            "ROIs2017_winter_s1/s1_104",
+            "ROIs2017_winter_s1/s1_81",
+            "ROIs2017_winter_s1/s1_146",
+            "ROIs2017_winter_s1/s1_75",
+            "ROIs2017_winter_s1/s1_94",
+            "ROIs2017_winter_s1/s1_102",
+            "ROIs2017_winter_s1/s1_61",
+            "ROIs2017_winter_s1/s1_47",
+            "ROIs1868_summer_s1/s1_100",  # note: this ROI is also used for testing in SEN12MS-CR-TS. If you wish to combine both datasets, please comment out this line
+        ]
+        self.splits["val"] = [
+            "ROIs2017_winter_s1/s1_22",
+            "ROIs1868_summer_s1/s1_19",
+            "ROIs1970_fall_s1/s1_65",
+            "ROIs1158_spring_s1/s1_17",
+            "ROIs2017_winter_s1/s1_107",
+            "ROIs1868_summer_s1/s1_80",
+            "ROIs1868_summer_s1/s1_127",
+            "ROIs2017_winter_s1/s1_130",
+            "ROIs1868_summer_s1/s1_17",
+            "ROIs2017_winter_s1/s1_84",
+        ]
+        self.splits["test"] = [
+            "ROIs1158_spring_s1/s1_106",
+            "ROIs1158_spring_s1/s1_123",
+            "ROIs1158_spring_s1/s1_140",
+            "ROIs1158_spring_s1/s1_31",
+            "ROIs1158_spring_s1/s1_44",
+            "ROIs1868_summer_s1/s1_119",
+            "ROIs1868_summer_s1/s1_73",
+            "ROIs1970_fall_s1/s1_139",
+            "ROIs2017_winter_s1/s1_108",
+            "ROIs2017_winter_s1/s1_63",
+        ]
+
+        self.splits["all"] = (
+            self.splits["train"] + self.splits["test"] + self.splits["val"]
+        )
+        self.split = split
+
+        assert split in [
+            "all",
+            "train",
+            "val",
+            "test",
+        ], "Input dataset must be either assigned as all, train, test, or val!"
+        assert sample_type in ["pretrain"], "Input data must be pretrain!"
+        assert cloud_masks in [
+            None,
+            "cloud_cloudshadow_mask",
+            "s2cloudless_map",
+            "s2cloudless_mask",
+        ], "Unknown cloud mask type!"
+
+        self.modalities = ["S1", "S2"]
+        self.cloud_masks = cloud_masks  # e.g. 'cloud_cloudshadow_mask', 's2cloudless_map', 's2cloudless_mask'
+        self.sample_type = sample_type  # e.g. 'pretrain'
+
+        self.time_points = range(1)
+        self.n_input_t = 1  # specifies the number of samples, if only part of the time series is used as an input
+
+        if self.cloud_masks in ["s2cloudless_map", "s2cloudless_mask"]:
+            self.cloud_detector = S2PixelCloudDetector(
+                threshold=0.4, all_bands=True, average_over=4, dilation_size=2
+            )
+        else:
+            self.cloud_detector = None
+
+        self.paths = self.get_paths()
+        self.n_samples = len(self.paths)
+
+        # raise a warning if no data has been found
+        if not self.n_samples:
+            self.throw_warn()
+
+        self.method = rescale_method
+
+    # indexes all patches contained in the current data split
+    def get_paths(
+        self,
+    ):  # assuming for the same ROI+num, the patch numbers are the same
+        print(f"\nProcessing paths for {self.split} split of region {self.region}")
+
+        paths = []
+        seeds_S1 = natsorted(
+            [s1dir for s1dir in os.listdir(self.root_dir) if "_s1" in s1dir]
+        )
+        for seed in seeds_S1:
+            rois_S1 = natsorted(os.listdir(os.path.join(self.root_dir, seed)))
+            for roi in rois_S1:
+                roi_dir = os.path.join(self.root_dir, seed, roi)
+                paths_S1 = natsorted(
+                    [os.path.join(roi_dir, s1patch) for s1patch in os.listdir(roi_dir)]
+                )
+                paths_S2 = [
+                    patch.replace("/s1", "/s2").replace("_s1", "_s2")
+                    for patch in paths_S1
+                ]
+                paths_S2_cloudy = [
+                    patch.replace("/s1", "/s2_cloudy").replace("_s1", "_s2_cloudy")
+                    for patch in paths_S1
+                ]
+
+                for pdx, _ in enumerate(paths_S1):
+                    # omit patches that are potentially unpaired
+                    if not all(
+                        [
+                            os.path.isfile(paths_S1[pdx]),
+                            os.path.isfile(paths_S2[pdx]),
+                            os.path.isfile(paths_S2_cloudy[pdx]),
+                        ]
+                    ):
+                        continue
+                    # don't add patch if not belonging to the selected split
+                    if not any(
+                        [
+                            split_roi in paths_S1[pdx]
+                            for split_roi in self.splits[self.split]
+                        ]
+                    ):
+                        continue
+                    sample = {
+                        "S1": paths_S1[pdx],
+                        "S2": paths_S2[pdx],
+                        "S2_cloudy": paths_S2_cloudy[pdx],
+                    }
+                    paths.append(sample)
+        return paths
+
+    def __getitem__(self, pdx):  # get the triplet of patch with ID pdx
+        s1_tif = read_tif(self.paths[pdx]["S1"])
+        s2_tif = read_tif(self.paths[pdx]["S2"])
+        s2_cloudy_tif = read_tif(self.paths[pdx]["S2_cloudy"])
+        coord = list(s2_tif.bounds)
+        s1 = process_SAR(read_img(s1_tif), self.method)
+        s2 = read_img(s2_tif)  # note: pre-processing happens after cloud detection
+        s2_cloudy = read_img(
+            s2_cloudy_tif
+        )  # note: pre-processing happens after cloud detection
+        mask = (
+            None
+            if not self.cloud_masks
+            else get_cloud_map(s2_cloudy, self.cloud_masks, self.cloud_detector)
+        )
+
+        sample = {
+            "input": {
+                "S1": s1,
+                "S2": process_MS(s2_cloudy, self.method),
+                "masks": mask,
+                "coverage": np.mean(mask),
+                "S1 path": os.path.join(self.root_dir, self.paths[pdx]["S1"]),
+                "S2 path": os.path.join(self.root_dir, self.paths[pdx]["S2_cloudy"]),
+                "coord": coord,
+            },
+            "target": {
+                "S2": process_MS(s2, self.method),
+                "S2 path": os.path.join(self.root_dir, self.paths[pdx]["S2"]),
+                "coord": coord,
+            },
+        }
+        return sample
+
+    def throw_warn(self):
+        warnings.warn(
+            """No data samples found! Please use the following directory structure:
+
+            path/to/your/SEN12MSCR/directory:
+            ├───ROIs1158_spring_s1
+            |   ├─s1_1
+            |   |   |...
+            |   |   ├─ROIs1158_spring_s1_1_p407.tif
+            |   |   |...
+            |   ...
+            ├───ROIs1158_spring_s2
+            |   ├─s2_1
+            |   |   |...
+            |   |   ├─ROIs1158_spring_s2_1_p407.tif
+            |   |   |...
+            |   ...
+            ├───ROIs1158_spring_s2_cloudy
+            |   ├─s2_cloudy_1
+            |   |   |...
+            |   |   ├─ROIs1158_spring_s2_cloudy_1_p407.tif
+            |   |   |...
+            |   ...
+            ...
+
+            Note: Please arrange the dataset in a format as e.g. provided by the script dl_data.sh.
+            """
+        )
+
+    def __len__(self):
+        # length of generated list
+        return self.n_samples
+
+
+if __name__ == "__main__":
+    dataset = SEN12MSCR(
+        root="data2/SEN12MSCR",
+        split="all",
+        region="all",
+        cloud_masks="s2cloudless_mask",
+        sample_type="pretrain",
+        rescale_method="default",
+    )
+    for each in dataset:
+        print(f"{each['input']['S1'].shape}")
+        print(f"{each['input']['S2'].shape}")
+        print(f"{each['input']['masks'].shape}")
+        print(f"{each['target']['S2'].shape}")
+        # (2, 256, 256)
+        # (13, 256, 256)
+        # (256, 256)
+        # (13, 256, 256)
+        break
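Several comments in `process_MS` and `process_SAR` above note that the [0, 1]-scaled intensities get "mapped to [-1,+1] in wrapper", but that wrapper is not part of this file. The sketch below is only an illustration of that remark; the subclass name and the batching settings are assumptions, not code from the repository.

```python
import torch
from torch.utils.data import DataLoader
from data.dataLoader import SEN12MSCR

class ScaledSEN12MSCR(SEN12MSCR):
    """Hypothetical wrapper mapping the [0, 1]-scaled patches to [-1, +1]."""

    def __getitem__(self, pdx):
        sample = super().__getitem__(pdx)
        # S1 and S2 come out of process_SAR/process_MS in [0, 1] for method='default'
        sample["input"]["S1"] = sample["input"]["S1"] * 2.0 - 1.0
        sample["input"]["S2"] = sample["input"]["S2"] * 2.0 - 1.0
        sample["target"]["S2"] = sample["target"]["S2"] * 2.0 - 1.0
        return sample

dataset = ScaledSEN12MSCR("/path/to/SEN12MSCR", split="train", cloud_masks="s2cloudless_mask")
loader = DataLoader(dataset, batch_size=4, shuffle=True, num_workers=4)
```

Keeping the range mapping in a thin wrapper like this leaves the on-disk preprocessing (`process_MS`, `process_SAR`) untouched, so the same dataset class can feed models that expect either value range.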
UnCRtainTS/environment.yaml
ADDED
@@ -0,0 +1,13 @@
+name: uncrtaints
+channels:
+  - nvidia
+  - pytorch
+  - defaults
+dependencies:
+  - nvidia::cudatoolkit=11.7
+  - python
+  - pip=20.3
+  - pytorch::pytorch=2.0.0
+  - numpy
+  - pip:
+
UnCRtainTS/model/.gitignore
ADDED
@@ -0,0 +1,136 @@
+# Byte-compiled / optimized / DLL files
+todo.txt
+__pycache__/
+*.py[cod]
+*$py.class
+*.swp
+# C extensions
+*.so
+.idea/
+# Distribution / packaging
+.Python
+build/
+develop-eggs/
+dist/
+downloads/
+eggs/
+.eggs/
+lib/
+lib64/
+parts/
+sdist/
+var/
+wheels/
+pip-wheel-metadata/
+share/python-wheels/
+*.egg-info/
+.installed.cfg
+*.egg
+
+
+.DS_Store
+MANIFEST
+
+# PyInstaller
+# Usually these files are written by a python script from a template
+# before PyInstaller builds the exe, so as to inject date/other infos into it.
+*.manifest
+*.spec
+
+# Installer logs
+pip-log.txt
+pip-delete-this-directory.txt
+
+# Unit test / coverage reports
+htmlcov/
+.tox/
+.nox/
+.coverage
+.coverage.*
+.cache
+nosetests.xml
+coverage.xml
+*.cover
+*.py,cover
+.hypothesis/
+.pytest_cache/
+
+# Translations
+*.mo
+*.pot
+
+# Django stuff:
+*.log
+local_settings.py
+db.sqlite3
+db.sqlite3-journal
+
+# Flask stuff:
+instance/
+.webassets-cache
+
+# Scrapy stuff:
+.scrapy
+
+# Sphinx documentation
+docs/_build/
+
+# PyBuilder
+target/
+
+# Jupyter Notebook
+.ipynb_checkpoints
+
+# IPython
+profile_default/
+ipython_config.py
+
+# pyenv
+.python-version
+
+# pipenv
+# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
+# However, in case of collaboration, if having platform-specific dependencies or dependencies
+# having no cross-platform support, pipenv may install dependencies that don't work, or not
+# install all needed dependencies.
+#Pipfile.lock
+
+# PEP 582; used by e.g. github.com/David-OConnor/pyflow
+__pypackages__/
+
+# Celery stuff
+celerybeat-schedule
+celerybeat.pid
+
+# SageMath parsed files
+*.sage.py
+
+# Environments
+.env
+.venv
+env/
+venv/
+ENV/
+env.bak/
+venv.bak/
+
+# Spyder project settings
+.spyderproject
+.spyproject
+
+# Rope project settings
+.ropeproject
+
+# mkdocs documentation
+/site
+
+# mypy
+.mypy_cache/
+.dmypy.json
+dmypy.json
+
+# Pyre type checker
+.pyre/
+
+# ignore particular folders and directories
+./util/precomputed
UnCRtainTS/model/checkpoint/diffcr_bs32_epoch17/model.pth.tar
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:733e42830f19c26569427dec26ecb63bf98a19a3d48d76df041d30290e3a28c6
+size 213833726
UnCRtainTS/model/checkpoint/monotemporalL2/conf.json
ADDED
@@ -0,0 +1,56 @@
+{
+    "model": "uncrtaints",
+    "encoder_widths": [
+        128
+    ],
+    "decoder_widths": [
+        128,
+        128,
+        128,
+        128,
+        128
+    ],
+    "out_conv": [
+        13
+    ],
+    "mean_nonLinearity": true,
+    "var_nonLinearity": "softplus",
+    "use_sar": true,
+    "agg_mode": "att_group",
+    "encoder_norm": "group",
+    "decoder_norm": "batch",
+    "n_head": 1,
+    "d_model": 256,
+    "use_v": false,
+    "positional_encoding": true,
+    "d_k": 4,
+    "res_dir": "./results",
+    "experiment_name": "monotemporalL2",
+    "device": "cuda",
+    "display_step": 10,
+    "batch_size": 4,
+    "lr": 0.001,
+    "gamma": 0.8,
+    "ref_date": "2014-04-03",
+    "pad_value": 0,
+    "padding_mode": "reflect",
+    "val_every": 1,
+    "val_after": 0,
+    "pretrain": true,
+    "input_t": 1,
+    "sample_type": "pretrain",
+    "vary_samples": true,
+    "min_cov": 0.0,
+    "max_cov": 1.0,
+    "region": "all",
+    "max_samples": 1000000000,
+    "input_size": 256,
+    "plot_every": -1,
+    "loss": "l2",
+    "covmode": "diag",
+    "scale_by": 10.0,
+    "separate_out": false,
+    "resume_from": false,
+    "epochs": 20,
+    "trained_checkp": ""
+}
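The conf.json above records every training flag for the shipped checkpoint. As a minimal sketch of what restoring such a configuration amounts to, the hypothetical helper below reads the file back into an `argparse.Namespace`; the helper name and the override mechanism are assumptions, not the repository's own loading code in `parse_args.py` or `test_reconstruct.py`.

```python
import json
from argparse import Namespace

def load_exported_conf(conf_path, overrides=None):
    """Read an exported conf.json and return the stored flags as a Namespace.

    `overrides` lets runtime-only settings (device, batch size, data paths)
    take precedence over the values recorded at training time.
    """
    with open(conf_path) as f:
        flags = json.load(f)
    if overrides:
        flags.update(overrides)
    return Namespace(**flags)

# example: restore the flags of the shipped mono-temporal L2 checkpoint
conf = load_exported_conf(
    "model/checkpoint/monotemporalL2/conf.json",
    overrides={"device": "cuda", "batch_size": 4},
)
print(conf.model, conf.loss, conf.input_t)  # uncrtaints l2 1
```

In practice the test script performs this step automatically when it is pointed at the exported configuration via `--weight_folder`, as described in the README.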
UnCRtainTS/model/checkpoint/monotemporalL2/model.pth.tar
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:b86c07efb75812e5f9564da3c6ad289be5f82d86f6aa837df37f48b01b11aabc
+size 6825365
UnCRtainTS/model/ensemble_reconstruct.py
ADDED
@@ -0,0 +1,180 @@
+"""
+Python script to obtain Deep Ensemble predictions by collecting each instance's pre-computed predictions.
+Each member's predictions are first meant to be pre-computed via test_reconstruct.py, with the outputs exported,
+and read again in this script. Online ensembling is currently not implemented as this may exceed hardware constraints.
+For every ensemble member, the path to its output directory has to be specified in the list 'ensemble_paths'.
+"""
+
+import os
+import sys
+import torch
+import numpy as np
+from tqdm import tqdm
+from natsort import natsorted
+
+dirname = os.path.dirname(os.path.abspath(__file__))
+sys.path.append(os.path.dirname(dirname))
+
+from data.dataLoader import SEN12MSCR, SEN12MSCRTS
+from src.learning.metrics import img_metrics, avg_img_metrics
+from train_reconstruct import recursive_todevice, compute_uce_auce, export, plot_img, save_results
+
+epoch = 1
+root = '/home/data/'  # path to directory containing dataset
+mode = 'test'  # split to evaluate on
+in_time = 3  # length of input time series
+region = 'all'  # region of areas of interest
+max_samples = 1e9  # maximum count of samples to consider
+uncertainty = 'both'  # e.g. 'aleatoric', 'epistemic', 'both' --- only matters if ensemble==True
+ensemble = True  # whether to compute ensemble mean and var or not
+pixelwise = True  # whether to summarize errors and variances for image-based AUCE and UCE or keep pixel-based statistics
+export_path = None  # where to export ensemble statistics, set to None if no writing to files is desired
+
+# define path to find the individual ensemble member's predictions in
+ensemble_paths = [os.path.join(dirname, 'inference', f'diagonal_1/export/epoch_{epoch}/{mode}'),
+                  os.path.join(dirname, 'inference', f'diagonal_2/export/epoch_{epoch}/{mode}'),
+                  os.path.join(dirname, 'inference', f'diagonal_3/export/epoch_{epoch}/{mode}'),
+                  os.path.join(dirname, 'inference', f'diagonal_4/export/epoch_{epoch}/{mode}'),
+                  os.path.join(dirname, 'inference', f'diagonal_5/export/epoch_{epoch}/{mode}'),
+                  ]
+
+n_ensemble = len(ensemble_paths)
+print('Ensembling over model predictions:')
+for instance in ensemble_paths: print(instance)
+
+if export_path:
+    plot_dir = os.path.join(export_path, 'plots', f'epoch_{epoch}', f'{mode}')
+    export_dir = os.path.join(export_path, 'export', f'epoch_{epoch}', f'{mode}')
+
+
+def prepare_data_multi(batch, device, batch_size=1, use_sar=True):
+    in_S2 = recursive_todevice(torch.tensor(batch['input']['S2']), device)
+    in_S2_td = recursive_todevice(torch.tensor(batch['input']['S2 TD']), device)
+    if batch_size>1: in_S2_td = torch.stack((in_S2_td)).T
+    in_m = recursive_todevice(torch.tensor(batch['input']['masks']), device)
+    target_S2 = recursive_todevice(torch.tensor(batch['target']['S2']), device)
+    y = target_S2
+
+    if use_sar:
+        in_S1 = recursive_todevice(torch.tensor(batch['input']['S1']), device)
+        in_S1_td = recursive_todevice(torch.tensor(batch['input']['S1 TD']), device)
+        if batch_size>1: in_S1_td = torch.stack((in_S1_td)).T
+        x = torch.cat((torch.stack(in_S1,dim=1), torch.stack(in_S2,dim=1)),dim=2)
+        dates = torch.stack((torch.tensor(in_S1_td),torch.tensor(in_S2_td))).float().mean(dim=0).to(device)
+    else:
+        x = in_S2  # torch.stack(in_S2,dim=1)
+        dates = torch.tensor(in_S2_td).float().to(device)
+
+    return x.unsqueeze(dim=0), y.unsqueeze(dim=0), in_m.unsqueeze(dim=0), dates
+
+
+def main():
+
+    # list all predictions of the first ensemble member
+    dataPath = ensemble_paths[0]
+    samples = natsorted([os.path.join(dataPath, f) for f in os.listdir(dataPath) if (os.path.isfile(os.path.join(dataPath, f)) and "_pred.npy" in f)])
+
+    # collect sample-averaged uncertainties and errors
+    img_meter = avg_img_metrics()
+    vars_aleatoric = []
+    errs, errs_se, errs_ae = [], [], []
+
+    import_data_path = os.path.join(os.getcwd(), 'util', 'precomputed', f'generic_{in_time}_{mode}_{region}_s2cloudless_mask.npy')
+    import_data_path = import_data_path if os.path.isfile(import_data_path) else None
+    dt_test = SEN12MSCRTS(os.path.join(root, 'SEN12MSCRTS'), split=mode, region=region, sample_type="cloudy_cloudfree" , n_input_samples=in_time, import_data_path=import_data_path)
+    if len(dt_test.paths) != len(samples): raise AssertionError
+
+    # iterate over the ensemble member's mean predictions
+    for idx, sample_mean in enumerate(tqdm(samples)):
+        if idx >= max_samples: break  # exceeded desired sample count
+
+        # fetch target data and cloud masks of idx-th sample from
+        batch = dt_test.getsample(idx)  # ... in order to compute metrics
+        x, y, in_m, _ = prepare_data_multi(batch, 'cuda', batch_size=1, use_sar=False)
+
+        try:
+            mean, var = [], []
+            for path in ensemble_paths:  # for each ensemble member ...
+                # ... load the member's mean predictions and ...
+                mean.append(np.load(os.path.join(path, os.path.basename(sample_mean))))
+                # ... load the member's covariance or var predictions
+                sample_var = sample_mean.replace('_pred', '_covar')
+                if not os.path.isfile(os.path.join(path, os.path.basename(sample_var))):
+                    sample_var = sample_mean.replace('_pred', '_var')
+                var.append(np.load(os.path.join(path, os.path.basename(sample_var))))
+        except:
+            # skip any sample for which not all members provide predictions
+            # (note: we also next'ed the dataloader's sample already)
+            print(f'Skipped sample {idx}, missing data.')
+            continue
+        mean, var = np.array(mean), np.array(var)
+
+        # get the variances from the covariance matrix
+        if len(var.shape) > 4:  # loaded covariance matrix
+            var = np.moveaxis(np.diagonal(var, axis1=1, axis2=2), -1, 1)
+
+        # combine predictions
+
+        if ensemble:
+            # get ensemble estimate and epistemic uncertainty,
+            # approximate 1 Gaussian by mixture parameter ensembling
+            mean_ensemble = 1/n_ensemble * np.sum(mean, axis=0)
+
+            if uncertainty == 'aleatoric':
+                # average the members' aleatoric uncertainties
+                var_ensemble = 1/n_ensemble * np.sum(var, axis=0)
+            elif uncertainty == 'epistemic':
+                # compute average variance of ensemble predictions
+                var_ensemble = 1/n_ensemble * np.sum(mean**2, axis=0) - mean_ensemble**2
+            elif uncertainty == 'both':
+                # combine both
+                var_ensemble = 1/n_ensemble * np.sum(var + mean**2, axis=0) - mean_ensemble**2
+            else: raise NotImplementedError
+        else: mean_ensemble, var_ensemble = mean[0], var[0]
+
+        mean_ensemble = torch.tensor(mean_ensemble).cuda()
+        var_ensemble = torch.tensor(var_ensemble).cuda()
+
+        # compute test metrics on ensemble prediction
+        extended_metrics = img_metrics(y[0], mean_ensemble.unsqueeze(dim=0),
+                                       var=var_ensemble.unsqueeze(dim=0),
+                                       pixelwise=pixelwise)
+        img_meter.add(extended_metrics)  # accumulate performances over the entire split
+
+        if pixelwise:  # collect variances and errors
+            vars_aleatoric.extend(extended_metrics['pixelwise var'])
+            errs.extend(extended_metrics['pixelwise error'])
+            errs_se.extend(extended_metrics['pixelwise se'])
+            errs_ae.extend(extended_metrics['pixelwise ae'])
+        else:
+            vars_aleatoric.append(extended_metrics['mean var'])
+            errs.append(extended_metrics['error'])
+            errs_se.append(extended_metrics['mean se'])
+            errs_ae.append(extended_metrics['mean ae'])
+
+        if export_path:  # plot and export ensemble predictions
+            plot_img(mean_ensemble.unsqueeze(dim=0), 'pred', plot_dir, file_id=idx)
+            plot_img(x[0], 'in', plot_dir, file_id=idx)
+            plot_img(var_ensemble.mean(dim=0, keepdims=True).expand(3, *var_ensemble.shape[1:]).unsqueeze(dim=0), 'var', plot_dir, file_id=idx)
+            export(mean_ensemble[None], 'pred', export_dir, file_id=idx)
+            export(var_ensemble[None], 'var', export_dir, file_id=idx)
+
+
+    # compute UCE and AUCE
+    uce_l2, auce_l2 = compute_uce_auce(vars_aleatoric, errs, len(vars_aleatoric), percent=5, l2=True, mode=mode, step=0)
+
+    # no need for a running mean here
+    img_meter.value()['UCE SE'] = uce_l2.cpu().numpy().item()
+    img_meter.value()['AUCE SE'] = auce_l2.cpu().numpy().item()
+
+    print(f'{mode} split image metrics: {img_meter.value()}')
+    if export_path:
+        np.save(os.path.join(export_path, f'pred_var_{uncertainty}.npy'), vars_aleatoric)
+        np.save(os.path.join(export_path, 'errors.npy'), errs)
+        save_results(img_meter.value(), export_path, split=mode)
+        print(f'Exported predictions to path {export_path}')
+
+
+if __name__ == "__main__":
+    main()
+    exit()
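The `uncertainty == 'both'` branch above moment-matches a uniform mixture of the members' Gaussian predictions. The toy numpy check below, with made-up array shapes, only illustrates that identity; it shows that the combined variance equals the averaged aleatoric variance plus the epistemic spread of the member means, and is not part of the repository.

```python
import numpy as np

# Moment matching of a uniform mixture of M Gaussians N(mu_i, var_i):
#   mu*  = (1/M) * sum_i mu_i
#   var* = (1/M) * sum_i (var_i + mu_i**2) - mu***2
rng = np.random.default_rng(0)
M = 5                                              # number of ensemble members
mu = rng.normal(size=(M, 13, 8, 8))                # per-member mean predictions (toy shapes)
var = rng.uniform(0.1, 1.0, size=(M, 13, 8, 8))    # per-member aleatoric variances

mu_star = mu.mean(axis=0)
var_star = (var + mu**2).mean(axis=0) - mu_star**2

# the combined variance splits into averaged aleatoric variance + epistemic spread
epistemic = (mu**2).mean(axis=0) - mu_star**2
assert np.allclose(var_star, var.mean(axis=0) + epistemic)
```

Setting `uncertainty` to `'aleatoric'` or `'epistemic'` in the script simply keeps one of these two terms.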
UnCRtainTS/model/inference/diffcr_bs32_epoch17/conf.json
ADDED
@@ -0,0 +1,73 @@
{
    "model": "uncrtaints",
    "experiment_name": "diffcr_bs32_epoch17",
    "res_dir": "./inference",
    "plot_every": -1,
    "export_every": 1,
    "resume_at": -1,
    "encoder_widths": [
        128
    ],
    "decoder_widths": [
        128,
        128,
        128,
        128,
        128
    ],
    "out_conv": [
        13
    ],
    "mean_nonLinearity": true,
    "var_nonLinearity": "softplus",
    "agg_mode": "att_group",
    "encoder_norm": "group",
    "decoder_norm": "batch",
    "block_type": "mbconv",
    "padding_mode": "reflect",
    "pad_value": 0,
    "n_head": 16,
    "d_model": 256,
    "positional_encoding": true,
    "d_k": 4,
    "low_res_size": 32,
    "use_v": false,
    "num_workers": 0,
    "rdm_seed": 1,
    "device": "cuda:6",
    "display_step": 10,
    "loss": "MGNLL",
    "resume_from": false,
    "unfreeze_after": 0,
    "epochs": 20,
    "batch_size": 32,
    "chunk_size": null,
    "lr": 0.01,
    "gamma": 1.0,
    "val_every": 1,
    "val_after": 0,
    "use_sar": true,
    "pretrain": true,
    "input_t": 1,
    "ref_date": "2014-04-03",
    "sample_type": "pretrain",
    "vary_samples": true,
    "min_cov": 0.0,
    "max_cov": 1.0,
    "root1": "/home/data/SEN12MSCRTS",
    "root2": "/home/data/SEN12MSCRTS",
    "root3": "data2/SEN12MSCR",
    "precomputed": "/home/code/UnCRtainTS/util/precomputed",
    "region": "all",
    "max_samples_count": 1000000000,
    "max_samples_frac": 1.0,
    "profile": false,
    "trained_checkp": "",
    "covmode": "diag",
    "scale_by": 1.0,
    "separate_out": false,
    "weight_folder": "checkpoint/",
    "use_custom": false,
    "load_config": "",
    "pid": 2049339
}
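Note: each run's conf.json stores the full argument namespace alongside its checkpoint. A minimal sketch of turning such a file back into an argparse.Namespace is shown below; the helper name and relative path are illustrative, and the repository's own --load_config handling may differ.

import json
import argparse

def load_conf(path):
    """Rebuild an argparse.Namespace from an exported conf.json (illustrative helper)."""
    with open(path) as f:
        conf = json.load(f)
    return argparse.Namespace(**conf)

# hypothetical usage, relative to the model/ directory
args = load_conf('inference/diffcr_bs32_epoch17/conf.json')
print(args.model, args.loss, args.batch_size)   # uncrtaints MGNLL 32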
UnCRtainTS/model/inference/monotemporalL2/conf.json
ADDED
@@ -0,0 +1,75 @@
{
    "model": "uncrtaints",
    "encoder_widths": [
        128
    ],
    "decoder_widths": [
        128,
        128,
        128,
        128,
        128
    ],
    "out_conv": [
        13
    ],
    "mean_nonLinearity": true,
    "var_nonLinearity": "softplus",
    "use_sar": true,
    "agg_mode": "att_group",
    "encoder_norm": "group",
    "decoder_norm": "batch",
    "n_head": 1,
    "d_model": 256,
    "use_v": false,
    "positional_encoding": true,
    "d_k": 4,
    "experiment_name": "monotemporalL2",
    "lr": 0.001,
    "gamma": 0.8,
    "ref_date": "2014-04-03",
    "pad_value": 0,
    "padding_mode": "reflect",
    "val_every": 1,
    "val_after": 0,
    "pretrain": true,
    "sample_type": "pretrain",
    "vary_samples": true,
    "max_samples": 1000000000,
    "input_size": 256,
    "loss": "l2",
    "covmode": "diag",
    "scale_by": 10.0,
    "separate_out": false,
    "resume_from": false,
    "epochs": 20,
    "res_dir": "./inference",
    "plot_every": -1,
    "export_every": 1,
    "resume_at": -1,
    "device": "cuda:6",
    "display_step": 10,
    "batch_size": 128,
    "input_t": 1,
    "min_cov": 0.0,
    "max_cov": 1.0,
    "root1": "/home/data/SEN12MSCRTS",
    "root2": "/home/data/SEN12MSCRTS",
    "root3": "data2/SEN12MSCR",
    "region": "all",
    "max_samples_count": 1000000000,
    "trained_checkp": "",
    "weight_folder": "checkpoint/",
    "pid": 2973877,
    "block_type": "mbconv",
    "low_res_size": 32,
    "num_workers": 0,
    "rdm_seed": 1,
    "unfreeze_after": 0,
    "chunk_size": null,
    "precomputed": "/home/code/UnCRtainTS/util/precomputed",
    "max_samples_frac": 1.0,
    "profile": false,
    "use_custom": false,
    "load_config": ""
}
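Note: the two inference configs above differ mainly in their loss and attention settings (MGNLL with n_head 16 versus plain l2 with n_head 1, batch_size 32 versus 128, scale_by 1.0 versus 10.0). A small, purely illustrative sketch for diffing two such files:

import json

def diff_confs(path_a, path_b):
    """Print keys whose values differ between two conf.json files (illustrative)."""
    a, b = (json.load(open(p)) for p in (path_a, path_b))
    for key in sorted(set(a) | set(b)):
        if a.get(key) != b.get(key):
            print(f'{key}: {a.get(key)!r} -> {b.get(key)!r}')

# hypothetical usage
diff_confs('inference/diffcr_bs32_epoch17/conf.json',
           'inference/monotemporalL2/conf.json')
# expected to list e.g. loss: 'MGNLL' -> 'l2', n_head: 16 -> 1, batch_size: 32 -> 128, ...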
UnCRtainTS/model/inference/monotemporalL2/test_metrics.json
ADDED
@@ -0,0 +1,11 @@
{
    "RMSE": 0.038980784788900186,
    "MAE": 0.02744151706378001,
    "PSNR": 28.900039257648842,
    "SAM": 8.320397798952893,
    "SSIM": 0.8797316785507024,
    "error": NaN,
    "mean se": NaN,
    "mean ae": NaN,
    "mean var": NaN
}
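Note: the exported metrics files contain bare NaN literals (presumably for the uncertainty fields that are not populated by this L2 run), which is not strict JSON; Python's json parser happens to map them to float('nan'). A small sketch of reading such a file while dropping the NaN entries, with an illustrative path:

import json
import math

def load_metrics(path):
    """Read a test_metrics.json and drop NaN-valued entries (illustrative helper)."""
    with open(path) as f:
        metrics = json.load(f)   # bare NaN literals become float('nan')
    return {k: v for k, v in metrics.items()
            if not (isinstance(v, float) and math.isnan(v))}

print(load_metrics('inference/monotemporalL2/test_metrics.json'))
# e.g. {'RMSE': 0.0389..., 'MAE': 0.0274..., 'PSNR': 28.90..., 'SAM': 8.32..., 'SSIM': 0.879...}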
UnCRtainTS/model/parse_args.py
ADDED
@@ -0,0 +1,95 @@
import argparse

S2_BANDS = 13

def create_parser(mode='train'):
    parser = argparse.ArgumentParser()
    # model parameters
    parser.add_argument(
        "--model",
        default='uncrtaints', # e.g. 'unet', 'utae', 'uncrtaints',
        type=str,
        help="Type of architecture to use. Can be one of: (utae/unet3d/fpn/convlstm/convgru/uconvlstm/buconvlstm)",
    )
    parser.add_argument("--experiment_name", default='my_first_experiment', help="Name of the current experiment",)

    # fast switching between default arguments, depending on train versus test mode
    if mode=='train':
        parser.add_argument("--res_dir", default="./results", help="Path to where the results are stored, e.g. ./results for training or ./inference for testing",)
        parser.add_argument("--plot_every", default=-1, type=int, help="Interval (in items) of exporting plots at validation or test time. Set -1 to disable")
        parser.add_argument("--export_every", default=-1, type=int, help="Interval (in items) of exporting data at validation or test time. Set -1 to disable")
        parser.add_argument("--resume_at", default=0, type=int, help="Epoch to resume training from (may re-weight --lr in the optimizer) or epoch to load checkpoint from at test time")
    elif mode=='test':
        parser.add_argument("--res_dir", default="./inference", type=str, help="Path to directory where results are written.")
        parser.add_argument("--plot_every", default=-1, type=int, help="Interval (in items) of exporting plots at validation or test time. Set -1 to disable")
        parser.add_argument("--export_every", default=1, type=int, help="Interval (in items) of exporting data at validation or test time. Set -1 to disable")
        parser.add_argument("--resume_at", default=-1, type=int, help="Epoch to load checkpoint from and run testing with (use -1 for best on validation split)")

    parser.add_argument("--encoder_widths", default="[128]", type=str, help="e.g. [64,64,64,128] for U-TAE or [128] for UnCRtainTS")
    parser.add_argument("--decoder_widths", default="[128,128,128,128,128]", type=str, help="e.g. [64,64,64,128] for U-TAE or [128,128,128,128,128] for UnCRtainTS")
    parser.add_argument("--out_conv", default=f"[{S2_BANDS}]", help="output CONV, note: if inserting another layer then consider treating normalizations separately")
    parser.add_argument("--mean_nonLinearity", dest="mean_nonLinearity", action="store_false", help="whether to apply a sigmoidal output nonlinearity to the mean prediction")
    parser.add_argument("--var_nonLinearity", default="softplus", type=str, help="how to squash the network's variance outputs [relu | softplus | elu ]")
    parser.add_argument("--agg_mode", default="att_group", type=str, help="type of temporal aggregation in L-TAE module")
    parser.add_argument("--encoder_norm", default="group", type=str, help="e.g. 'group' (when using many channels) or 'instance' (for few channels)")
    parser.add_argument("--decoder_norm", default="batch", type=str, help="e.g. 'group' (when using many channels) or 'instance' (for few channels)")
    parser.add_argument("--block_type", default="mbconv", type=str, help="type of CONV block to use [residual | mbconv]")
    parser.add_argument("--padding_mode", default="reflect", type=str)
    parser.add_argument("--pad_value", default=0, type=float)

    # attention-specific parameters
    parser.add_argument("--n_head", default=16, type=int, help="default value of 16, 4 for debugging")
    parser.add_argument("--d_model", default=256, type=int, help="layers in L-TAE, default value of 256")
    parser.add_argument("--positional_encoding", dest="positional_encoding", action="store_false", help="whether to use positional encoding or not")
    parser.add_argument("--d_k", default=4, type=int)
    parser.add_argument("--low_res_size", default=32, type=int, help="resolution to downsample to")
    parser.add_argument("--use_v", dest="use_v", action="store_true", help="whether to use values v or not")

    # set-up parameters
    parser.add_argument("--num_workers", default=0, type=int, help="Number of data loading workers")
    parser.add_argument("--rdm_seed", default=1, type=int, help="Random seed")
    parser.add_argument("--device",default="cuda",type=str,help="Name of device to use for tensor computations (cuda/cpu)",)
    parser.add_argument("--display_step", default=10, type=int, help="Interval in batches between display of training metrics",)

    # training parameters
    parser.add_argument("--loss", default="MGNLL", type=str, help="Image reconstruction loss to utilize [l1|l2|GNLL|MGNLL].")
    parser.add_argument("--resume_from", dest="resume_from", action="store_true", help="resume training acc. to JSON in --experiment_name and *.pth chckp in --trained_checkp")
    parser.add_argument("--unfreeze_after", default=0, type=int, help="When to unfreeze ALL weights for training")
    parser.add_argument("--epochs", default=20, type=int, help="Number of epochs to train")
    parser.add_argument("--batch_size", default=4, type=int, help="Batch size")
    parser.add_argument("--chunk_size", type=int, help="Size of vmap batches, this can be adjusted to accommodate for additional memory needs")
    parser.add_argument("--lr", default=1e-2, type=float, help="Learning rate, e.g. 0.01")
    parser.add_argument("--gamma", default=1.0, type=float, help="Learning rate decay parameter for scheduler")
    parser.add_argument("--val_every", default=1, type=int, help="Interval in epochs between two validation steps.")
    parser.add_argument("--val_after", default=0, type=int, help="Do validation only after that many epochs.")

    # flags specific to SEN12MS-CR and SEN12MS-CR-TS
    parser.add_argument("--use_sar", dest="use_sar", action="store_true", help="whether to use SAR or not")
    parser.add_argument("--pretrain", dest="pretrain", action="store_true", help="whether to perform pretraining on SEN12MS-CR or training on SEN12MS-CR-TS")
    parser.add_argument("--input_t", default=3, type=int, help="number of input time points to sample, unet3d needs at least 4 time points")
    parser.add_argument("--ref_date", default="2014-04-03", type=str, help="reference date for Sentinel observations")
    parser.add_argument("--sample_type", default="cloudy_cloudfree", type=str, help="type of samples returned [cloudy_cloudfree | generic]")
    parser.add_argument("--vary_samples", dest="vary_samples", action="store_false", help="whether to sample different time points across epochs or not")
    parser.add_argument("--min_cov", default=0.0, type=float, help="The minimum cloud coverage to accept per input sample at train time. Gets overwritten by --vary_samples")
    parser.add_argument("--max_cov", default=1.0, type=float, help="The maximum cloud coverage to accept per input sample at train time. Gets overwritten by --vary_samples")
    parser.add_argument("--root1", default='/home/data/SEN12MSCRTS', type=str, help="path to your copy of SEN12MS-CR-TS")
    parser.add_argument("--root2", default='/home/data/SEN12MSCRTS', type=str, help="path to your copy of SEN12MS-CR-TS validation & test splits")
    parser.add_argument("--root3", default='/home/data/SEN12MSCR', type=str, help="path to your copy of SEN12MS-CR for pretraining")
    parser.add_argument("--precomputed", default='/home/code/UnCRtainTS/util/precomputed', type=str, help="path to pre-computed cloud statistics")
    parser.add_argument("--region", default="all", type=str, help="region to (sub-)sample ROI from [all|africa|america|asiaEast|asiaWest|europa]")
    parser.add_argument("--max_samples_count", default=int(1e9), type=int, help="count of data (sub-)samples to take")
    parser.add_argument("--max_samples_frac", default=1.0, type=float, help="fraction of data (sub-)samples to take")
    parser.add_argument("--profile", dest="profile", action="store_true", help="whether to profile code or not")
    parser.add_argument("--trained_checkp", default="", type=str, help="Path to loading a pre-trained network *.pth file, rather than initializing weights randomly")

    # flags specific to uncertainty modeling
    parser.add_argument("--covmode", default='diag', type=str, help="covariance matrix type [uni|iso|diag].")
    parser.add_argument("--scale_by", default=1.0, type=float, help="rescale data within model, e.g. to [0,10]")
    parser.add_argument("--separate_out", dest="separate_out", action="store_true", help="whether to separately process mean and variance predictions or in a shared layer")

    # flags specific for testing
    parser.add_argument("--weight_folder", type=str, default="./results", help="Path to the main folder containing the pre-trained weights")
    parser.add_argument("--use_custom", dest="use_custom", action="store_true", help="whether to test on individually specified patches or not")
    parser.add_argument("--load_config", default='', type=str, help="path of conf.json file to load")

    return parser
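Note: create_parser switches its res_dir/export_every/resume_at defaults between training and testing. A short usage sketch follows; it assumes the script runs from the model/ directory, and the flag values simply mirror the monotemporalL2 config above as an example.

# Illustrative use of create_parser from parse_args.py.
from parse_args import create_parser

parser = create_parser(mode='test')
config = parser.parse_args([
    '--experiment_name', 'monotemporalL2',
    '--weight_folder', 'checkpoint/',
    '--loss', 'l2',
    '--input_t', '1',
    '--use_sar',
])
print(config.res_dir, config.resume_at)   # ./inference -1 (test-mode defaults)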
UnCRtainTS/model/results/diffcr_bs32_lr15e-4/conf.json
ADDED
@@ -0,0 +1,74 @@
{
    "model": "uncrtaints",
    "experiment_name": "diffcr_bs32_lr15e-4",
    "res_dir": "./results",
    "plot_every": -1,
    "export_every": -1,
    "resume_at": 0,
    "encoder_widths": [
        128
    ],
    "decoder_widths": [
        128,
        128,
        128,
        128,
        128
    ],
    "out_conv": [
        13
    ],
    "mean_nonLinearity": false,
    "var_nonLinearity": "softplus",
    "agg_mode": "att_group",
    "encoder_norm": "group",
    "decoder_norm": "batch",
    "block_type": "mbconv",
    "padding_mode": "reflect",
    "pad_value": 0.0,
    "n_head": 1,
    "d_model": 256,
    "positional_encoding": false,
    "d_k": 4,
    "low_res_size": 32,
    "use_v": false,
    "num_workers": 16,
    "rdm_seed": 1,
    "device": "cuda:0",
    "display_step": 10,
    "loss": "l2",
    "resume_from": false,
    "unfreeze_after": 0,
    "epochs": 100,
    "batch_size": 32,
    "chunk_size": null,
    "lr": 0.0005,
    "gamma": 0.8,
    "val_every": 1,
    "val_after": 0,
    "use_sar": true,
    "pretrain": true,
    "input_t": 1,
    "ref_date": "2014-04-03",
    "sample_type": "pretrain",
    "vary_samples": false,
    "min_cov": 0.0,
    "max_cov": 1.0,
    "root1": "/home/data/SEN12MSCRTS",
    "root2": "/home/data/SEN12MSCRTS",
    "root3": "data2/SEN12MSCR",
    "precomputed": "/home/code/UnCRtainTS/util/precomputed",
    "region": "all",
    "max_samples_count": 1000000000,
    "max_samples_frac": 1.0,
    "profile": false,
    "trained_checkp": "",
    "covmode": "diag",
    "scale_by": 10.0,
    "separate_out": false,
    "weight_folder": "./results",
    "use_custom": false,
    "load_config": "",
    "pid": 2877152,
    "N_params": 19322381
}
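Note: this training run uses lr 0.0005 with gamma 0.8. Assuming gamma feeds a simple per-epoch multiplicative decay, as its help string in parse_args.py suggests, the effective rate after e epochs would be 0.0005 * 0.8**e; the exact scheduler type is an assumption. A tiny worked sketch:

# Learning-rate decay implied by lr=0.0005, gamma=0.8 under an assumed
# per-epoch multiplicative schedule.
lr, gamma = 5e-4, 0.8
for epoch in [0, 10, 36]:
    print(epoch, lr * gamma ** epoch)
# 0 -> 5.0e-4, 10 -> ~5.4e-5, 36 -> ~1.6e-7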
UnCRtainTS/model/results/diffcr_bs32_lr15e-4/model.pth.tar
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:35fd168203bc0ba6bf830851778d4958658be218c9b17e047592737dc68e49b2
size 213825786
UnCRtainTS/model/results/diffcr_bs32_lr15e-4/model_epoch_11.pth.tar
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:1c180b52ca34e85f3f29c3296b179723975c0326f6f3fc5efb3ea885badfdb92
size 213833726
UnCRtainTS/model/results/diffcr_bs32_lr15e-4/model_epoch_36.pth.tar
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:c6b156aee172142a2e5a530f7cef87675b8871985a410aefd03ae11004260a58
size 213833726
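Note: the *.pth.tar entries above are Git LFS pointer files rather than the weights themselves; each records the spec version, the sha256 oid of the actual blob, and its size in bytes (~214 MB here). A purely illustrative sketch for parsing such a pointer:

def parse_lfs_pointer(text):
    """Parse a Git LFS pointer file into a dict (illustrative; format as shown above)."""
    fields = dict(line.split(' ', 1) for line in text.strip().splitlines())
    return {
        'version': fields['version'],
        'sha256': fields['oid'].split(':', 1)[1],
        'size_bytes': int(fields['size']),
    }

pointer = """version https://git-lfs.github.com/spec/v1
oid sha256:35fd168203bc0ba6bf830851778d4958658be218c9b17e047592737dc68e49b2
size 213825786"""
print(parse_lfs_pointer(pointer)['size_bytes'] / 1e6, 'MB')   # ~213.8 MB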
UnCRtainTS/model/results/diffcr_bs32_lr15e-4/test_epoch_10_metrics.json
ADDED
@@ -0,0 +1,11 @@
{
    "RMSE": 0.029363970156402505,
    "MAE": 0.020022216210124334,
    "PSNR": 31.572881788434003,
    "SAM": 5.883645729394377,
    "SSIM": 0.8995029243551826,
    "error": NaN,
    "mean se": NaN,
    "mean ae": NaN,
    "mean var": NaN
}
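Note: the per-epoch metrics files below all share this layout, so checkpoint selection can be done by sweeping them. A small sketch that picks the epoch with the highest PSNR; the directory path is illustrative.

import glob
import json
import re

# Illustrative sweep over the per-epoch metrics files in this results folder.
best = max(
    glob.glob('results/diffcr_bs32_lr15e-4/test_epoch_*_metrics.json'),
    key=lambda p: json.load(open(p))['PSNR'],
)
epoch = int(re.search(r'test_epoch_(\d+)_metrics', best).group(1))
print(epoch, json.load(open(best))['PSNR'])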
UnCRtainTS/model/results/diffcr_bs32_lr15e-4/test_epoch_11_metrics.json
ADDED
@@ -0,0 +1,11 @@
{
    "RMSE": 0.028865320518000174,
    "MAE": 0.019421607403492344,
    "PSNR": 31.768579506695946,
    "SAM": 5.820518317864266,
    "SSIM": 0.9015902346229956,
    "error": NaN,
    "mean se": NaN,
    "mean ae": NaN,
    "mean var": NaN
}
UnCRtainTS/model/results/diffcr_bs32_lr15e-4/test_epoch_12_metrics.json
ADDED
@@ -0,0 +1,11 @@
{
    "RMSE": 0.028450740317855727,
    "MAE": 0.019143247260043655,
    "PSNR": 31.91159520461337,
    "SAM": 5.6859444906808125,
    "SSIM": 0.9034519373110392,
    "error": NaN,
    "mean se": NaN,
    "mean ae": NaN,
    "mean var": NaN
}
UnCRtainTS/model/results/diffcr_bs32_lr15e-4/test_epoch_13_metrics.json
ADDED
@@ -0,0 +1,11 @@
{
    "RMSE": 0.02858291424103222,
    "MAE": 0.01927231159917463,
    "PSNR": 31.870293692465435,
    "SAM": 5.595422786736038,
    "SSIM": 0.9039587730968367,
    "error": NaN,
    "mean se": NaN,
    "mean ae": NaN,
    "mean var": NaN
}
UnCRtainTS/model/results/diffcr_bs32_lr15e-4/test_epoch_14_metrics.json
ADDED
@@ -0,0 +1,11 @@
{
    "RMSE": 0.028029547319395425,
    "MAE": 0.018867645016614896,
    "PSNR": 32.0631246642833,
    "SAM": 5.500893342840581,
    "SSIM": 0.9053583295002058,
    "error": NaN,
    "mean se": NaN,
    "mean ae": NaN,
    "mean var": NaN
}
UnCRtainTS/model/results/diffcr_bs32_lr15e-4/test_epoch_15_metrics.json
ADDED
@@ -0,0 +1,11 @@
{
    "RMSE": 0.028002896943423252,
    "MAE": 0.018806874698693236,
    "PSNR": 32.124292399735644,
    "SAM": 5.547215727374945,
    "SSIM": 0.9059335617932202,
    "error": NaN,
    "mean se": NaN,
    "mean ae": NaN,
    "mean var": NaN
}
UnCRtainTS/model/results/diffcr_bs32_lr15e-4/test_epoch_16_metrics.json
ADDED
@@ -0,0 +1,11 @@
{
    "RMSE": 0.027591812040673,
    "MAE": 0.018454319715277293,
    "PSNR": 32.246636826172086,
    "SAM": 5.4618273344745045,
    "SSIM": 0.9069541601278078,
    "error": NaN,
    "mean se": NaN,
    "mean ae": NaN,
    "mean var": NaN
}
UnCRtainTS/model/results/diffcr_bs32_lr15e-4/test_epoch_17_metrics.json
ADDED
@@ -0,0 +1,11 @@
{
    "RMSE": 0.027644014190348384,
    "MAE": 0.018551536467928936,
    "PSNR": 32.24370574226725,
    "SAM": 5.461317134667159,
    "SSIM": 0.9069266391564054,
    "error": NaN,
    "mean se": NaN,
    "mean ae": NaN,
    "mean var": NaN
}
UnCRtainTS/model/results/diffcr_bs32_lr15e-4/test_epoch_18_metrics.json
ADDED
@@ -0,0 +1,11 @@
{
    "RMSE": 0.027595730930196957,
    "MAE": 0.018533841286411494,
    "PSNR": 32.25086082007444,
    "SAM": 5.458732208793296,
    "SSIM": 0.9071483427338218,
    "error": NaN,
    "mean se": NaN,
    "mean ae": NaN,
    "mean var": NaN
}
UnCRtainTS/model/results/diffcr_bs32_lr15e-4/test_epoch_19_metrics.json
ADDED
@@ -0,0 +1,11 @@
{
    "RMSE": 0.027390471058173046,
    "MAE": 0.018348841181363623,
    "PSNR": 32.342438679101114,
    "SAM": 5.419562214702756,
    "SSIM": 0.9077709371763801,
    "error": NaN,
    "mean se": NaN,
    "mean ae": NaN,
    "mean var": NaN
}
UnCRtainTS/model/results/diffcr_bs32_lr15e-4/test_epoch_1_metrics.json
ADDED
@@ -0,0 +1,11 @@
{
    "RMSE": 0.044388687314263,
    "MAE": 0.03238296521911765,
    "PSNR": 27.6217471128898,
    "SAM": 10.4182609624991,
    "SSIM": 0.763119293530315,
    "error": NaN,
    "mean se": NaN,
    "mean ae": NaN,
    "mean var": NaN
}
UnCRtainTS/model/results/diffcr_bs32_lr15e-4/test_epoch_20_metrics.json
ADDED
@@ -0,0 +1,11 @@
{
    "RMSE": 0.02744433999974496,
    "MAE": 0.018402347492774685,
    "PSNR": 32.3317267860638,
    "SAM": 5.421831781584368,
    "SSIM": 0.9078204176338086,
    "error": NaN,
    "mean se": NaN,
    "mean ae": NaN,
    "mean var": NaN
}
UnCRtainTS/model/results/diffcr_bs32_lr15e-4/test_epoch_21_metrics.json
ADDED
@@ -0,0 +1,11 @@
{
    "RMSE": 0.02734196908368276,
    "MAE": 0.018271456095607545,
    "PSNR": 32.37505691355452,
    "SAM": 5.414982792168764,
    "SSIM": 0.9081676423451513,
    "error": NaN,
    "mean se": NaN,
    "mean ae": NaN,
    "mean var": NaN
}
UnCRtainTS/model/results/diffcr_bs32_lr15e-4/test_epoch_22_metrics.json
ADDED
@@ -0,0 +1,11 @@
{
    "RMSE": 0.027245160590102936,
    "MAE": 0.018224454178698474,
    "PSNR": 32.41090423394659,
    "SAM": 5.398054870446574,
    "SSIM": 0.9083186444307726,
    "error": NaN,
    "mean se": NaN,
    "mean ae": NaN,
    "mean var": NaN
}
UnCRtainTS/model/results/diffcr_bs32_lr15e-4/test_epoch_23_metrics.json
ADDED
@@ -0,0 +1,11 @@
{
    "RMSE": 0.02729571228139796,
    "MAE": 0.01828788889602147,
    "PSNR": 32.395462209747514,
    "SAM": 5.393570169404665,
    "SSIM": 0.9081527106459147,
    "error": NaN,
    "mean se": NaN,
    "mean ae": NaN,
    "mean var": NaN
}
UnCRtainTS/model/results/diffcr_bs32_lr15e-4/test_epoch_24_metrics.json
ADDED
@@ -0,0 +1,11 @@
{
    "RMSE": 0.027244927966896305,
    "MAE": 0.018264471261733865,
    "PSNR": 32.41291005914153,
    "SAM": 5.392408405063824,
    "SSIM": 0.908316025697424,
    "error": NaN,
    "mean se": NaN,
    "mean ae": NaN,
    "mean var": NaN
}
UnCRtainTS/model/results/diffcr_bs32_lr15e-4/test_epoch_25_metrics.json
ADDED
@@ -0,0 +1,11 @@
{
    "RMSE": 0.027284221707286997,
    "MAE": 0.018280692685820887,
    "PSNR": 32.40145441845609,
    "SAM": 5.39122454551732,
    "SSIM": 0.9083838710484244,
    "error": NaN,
    "mean se": NaN,
    "mean ae": NaN,
    "mean var": NaN
}
UnCRtainTS/model/results/diffcr_bs32_lr15e-4/test_epoch_26_metrics.json
ADDED
@@ -0,0 +1,11 @@
{
    "RMSE": 0.02726639931698313,
    "MAE": 0.01826652071185466,
    "PSNR": 32.41051288700818,
    "SAM": 5.3847173492672225,
    "SSIM": 0.9083935374216957,
    "error": NaN,
    "mean se": NaN,
    "mean ae": NaN,
    "mean var": NaN
}
UnCRtainTS/model/results/diffcr_bs32_lr15e-4/test_epoch_27_metrics.json
ADDED
@@ -0,0 +1,11 @@
{
    "RMSE": 0.027239437090096214,
    "MAE": 0.018245351712652874,
    "PSNR": 32.42091501759734,
    "SAM": 5.386129859399605,
    "SSIM": 0.9085140088899692,
    "error": NaN,
    "mean se": NaN,
    "mean ae": NaN,
    "mean var": NaN
}
UnCRtainTS/model/results/diffcr_bs32_lr15e-4/test_epoch_28_metrics.json
ADDED
@@ -0,0 +1,11 @@
{
    "RMSE": 0.027215027071061892,
    "MAE": 0.018221582465008372,
    "PSNR": 32.42838631419994,
    "SAM": 5.384088766455628,
    "SSIM": 0.9085492332580157,
    "error": NaN,
    "mean se": NaN,
    "mean ae": NaN,
    "mean var": NaN
}
UnCRtainTS/model/results/diffcr_bs32_lr15e-4/test_epoch_29_metrics.json
ADDED
@@ -0,0 +1,11 @@
{
    "RMSE": 0.027242751470667303,
    "MAE": 0.01824946738290542,
    "PSNR": 32.41849800725127,
    "SAM": 5.381040883963731,
    "SSIM": 0.9085034823882493,
    "error": NaN,
    "mean se": NaN,
    "mean ae": NaN,
    "mean var": NaN
}
UnCRtainTS/model/results/diffcr_bs32_lr15e-4/test_epoch_2_metrics.json
ADDED
@@ -0,0 +1,11 @@
{
    "RMSE": 0.04118385424327889,
    "MAE": 0.029236674496597642,
    "PSNR": 28.28133535765438,
    "SAM": 8.635967488469356,
    "SSIM": 0.8477815685430351,
    "error": NaN,
    "mean se": NaN,
    "mean ae": NaN,
    "mean var": NaN
}
UnCRtainTS/model/results/diffcr_bs32_lr15e-4/test_epoch_30_metrics.json
ADDED
@@ -0,0 +1,11 @@
{
    "RMSE": 0.027199031498741918,
    "MAE": 0.01820931484350237,
    "PSNR": 32.43442219682148,
    "SAM": 5.381641721873335,
    "SSIM": 0.9086828993640171,
    "error": NaN,
    "mean se": NaN,
    "mean ae": NaN,
    "mean var": NaN
}
UnCRtainTS/model/results/diffcr_bs32_lr15e-4/test_epoch_31_metrics.json
ADDED
@@ -0,0 +1,11 @@
{
    "RMSE": 0.027213316269354252,
    "MAE": 0.01822304026870589,
    "PSNR": 32.43005099008614,
    "SAM": 5.380723569767104,
    "SSIM": 0.9086281500793518,
    "error": NaN,
    "mean se": NaN,
    "mean ae": NaN,
    "mean var": NaN
}
UnCRtainTS/model/results/diffcr_bs32_lr15e-4/test_epoch_32_metrics.json
ADDED
@@ -0,0 +1,11 @@
{
    "RMSE": 0.027218542108819188,
    "MAE": 0.01823459587795115,
    "PSNR": 32.429575243876165,
    "SAM": 5.380705763062456,
    "SSIM": 0.9086230123762922,
    "error": NaN,
    "mean se": NaN,
    "mean ae": NaN,
    "mean var": NaN
}
UnCRtainTS/model/results/diffcr_bs32_lr15e-4/test_epoch_33_metrics.json
ADDED
@@ -0,0 +1,11 @@
{
    "RMSE": 0.027201242979035374,
    "MAE": 0.01821136096979529,
    "PSNR": 32.4347129297312,
    "SAM": 5.380752220457475,
    "SSIM": 0.9086639188425286,
    "error": NaN,
    "mean se": NaN,
    "mean ae": NaN,
    "mean var": NaN
}
UnCRtainTS/model/results/diffcr_bs32_lr15e-4/test_epoch_34_metrics.json
ADDED
@@ -0,0 +1,11 @@
{
    "RMSE": 0.02720954676425417,
    "MAE": 0.01821668226351625,
    "PSNR": 32.43351948486106,
    "SAM": 5.378364858830772,
    "SSIM": 0.9086556218518279,
    "error": NaN,
    "mean se": NaN,
    "mean ae": NaN,
    "mean var": NaN
}
UnCRtainTS/model/results/diffcr_bs32_lr15e-4/test_epoch_35_metrics.json
ADDED
@@ -0,0 +1,11 @@
{
    "RMSE": 0.027196117778619137,
    "MAE": 0.018208952186562502,
    "PSNR": 32.43714808439791,
    "SAM": 5.377404463051186,
    "SSIM": 0.9086787752794829,
    "error": NaN,
    "mean se": NaN,
    "mean ae": NaN,
    "mean var": NaN
}
UnCRtainTS/model/results/diffcr_bs32_lr15e-4/test_epoch_36_metrics.json
ADDED
@@ -0,0 +1,11 @@
{
    "RMSE": 0.027197365488975264,
    "MAE": 0.018207441006122857,
    "PSNR": 32.437294496043016,
    "SAM": 5.378433104420995,
    "SSIM": 0.908667840788918,
    "error": NaN,
    "mean se": NaN,
    "mean ae": NaN,
    "mean var": NaN
}
UnCRtainTS/model/results/diffcr_bs32_lr15e-4/test_epoch_3_metrics.json
ADDED
@@ -0,0 +1,11 @@
{
    "RMSE": 0.03848631408196211,
    "MAE": 0.02771346952286548,
    "PSNR": 28.855159436197052,
    "SAM": 7.957361651426766,
    "SSIM": 0.8585438828465877,
    "error": NaN,
    "mean se": NaN,
    "mean ae": NaN,
    "mean var": NaN
}