diff --git a/UnCRtainTS/.gitignore b/UnCRtainTS/.gitignore new file mode 100644 index 0000000000000000000000000000000000000000..aaa6f95b693fa34e59731d43c1ef7c5af8c73e70 --- /dev/null +++ b/UnCRtainTS/.gitignore @@ -0,0 +1,6 @@ +*.npy +logs +model/inference +model/checkpoint +model/results +*_pycache_* \ No newline at end of file diff --git a/UnCRtainTS/Dockerfile b/UnCRtainTS/Dockerfile new file mode 100644 index 0000000000000000000000000000000000000000..c3303e706165eaabe06673ff28897faddc80bf1d --- /dev/null +++ b/UnCRtainTS/Dockerfile @@ -0,0 +1,35 @@ +FROM pytorch/pytorch:2.0.0-cuda11.7-cudnn8-runtime + +# install dependencies +#RUN pip install functorch +# '' may actually no longer be needed from torch 1.13 on +RUN pip install cupy-cuda112 +RUN conda install -c conda-forge cupy +#RUN conda install pytorch torchvision cudatoolkit=11.7 -c pytorch +RUN pip install opencv-python +# RUN conda install -c conda-forge opencv +RUN pip install scipy rasterio natsort matplotlib scikit-image tqdm pandas +RUN pip install Pillow dominate visdom tensorboard +RUN pip install kornia torchgeometry torchmetrics torchnet segmentation-models-pytorch +RUN pip install s2cloudless +# see: https://github.com/sentinel-hub/sentinel2-cloud-detector/issues/17 +RUN pip install numpy==1.21.6 + +RUN apt-get -y update +RUN apt-get -y install git +RUN pip install -U 'git+https://github.com/facebookresearch/fvcore' + +# just in case some last-minute changes are needed +RUN apt-get install nano + +# bake repository into dockerfile +RUN mkdir -p ./data +RUN mkdir -p ./model +RUN mkdir -p ./util + +ADD data ./data +ADD model ./model +ADD util ./util +ADD . ./ + +WORKDIR /workspace/model \ No newline at end of file diff --git a/UnCRtainTS/README.md b/UnCRtainTS/README.md new file mode 100644 index 0000000000000000000000000000000000000000..ad397c5ca7610e8a4ab0ad88c2da5fb6dc2a2948 --- /dev/null +++ b/UnCRtainTS/README.md @@ -0,0 +1,109 @@ +# UnCRtainTS: Uncertainty Quantification for Cloud Removal in Optical Satellite Time Series + +![banner](architecture.png) +> +> _This is the official repository for UnCRtainTS, a network for multi-temporal cloud removal in satellite data combining a novel attention-based architecture, and a formulation for multivariate uncertainty prediction. These two components combined set a new state-of-the-art performance in terms of image reconstruction on two public cloud removal datasets. Additionally, we show how the well-calibrated predicted uncertainties enable a precise control of the reconstruction quality._ +---- +This repository contains code accompanying the paper +> P. Ebel, V. Garnot, M. Schmitt, J. Wegner and X. X. Zhu. UnCRtainTS: Uncertainty Quantification for Cloud Removal in Optical Satellite Time Series. Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition Workshops, 2023. + +For additional information: + +* The publication is available in the [CVPRW Proceedings](https://openaccess.thecvf.com/content/CVPR2023W/EarthVision/papers/Ebel_UnCRtainTS_Uncertainty_Quantification_for_Cloud_Removal_in_Optical_Satellite_Time_CVPRW_2023_paper.pdf). +* The SEN12MS-CR-TS data set is accessible at the MediaTUM page [here](https://mediatum.ub.tum.de/1639953) (train split) and [here](https://mediatum.ub.tum.de/1659251) (test split). +* You can find additional information on this and related projects on the associated [cloud removal projects page](https://patrickTUM.github.io/cloud_removal/). +* For any further questions, please reach out to me here or via the credentials on my [website](https://pwjebel.com). +--- + +## Installation +### Dataset + +You can easily download the multi-temporal SEN12MS-CR-TS (and, optionally, the mono-temporal SEN12MS-CR) dataset via the shell script in [`./util/dl_data.sh`](https://github.com/PatrickTUM/UnCRtainTS/blob/main/util/dl_data.sh). Alternatively, you may download the SEN12MS-CR-TS data set (or parts of it) via the MediaTUM website [here](https://mediatum.ub.tum.de/1639953) (train split) and [here](https://mediatum.ub.tum.de/1659251) (test split), with further instructions provided in the dataset's own [dedicated repository](https://github.com/PatrickTUM/SEN12MS-CR-TS#dataset). + +### Code +Clone this repository via `git clone https://github.com/PatrickTUM/UnCRtainTS.git`. + +and set up the Python environment via + +```bash +conda env create --file environment.yaml +conda activate uncrtaints +``` + +Alternatively, you may install all that's needed via +```bash +pip install -r requirements.txt +``` +or by building a Docker image of `Dockerfile` and deploying a container. + +The code is written in Python 3 and uses PyTorch $\geq$ 2.0. It is strongly recommended to run the code with CUDA and GPU support. The code has been developed and deployed in Ubuntu 20 LTS and should be able to run in any comparable OS. + +--- + +## Usage +### Dataset +If you already have your own model in place or wish to build one on the SEN12MS-CR-TS data loader for training and testing, the data loader can be used as a stand-alone script as demonstrated in `./standalone_dataloader.py`. This only requires the files `./data/dataLoader.py` (the actual data loader) and `./util/detect_cloudshadow.py` (if this type of cloud detector is chosen). + +For using the dataset as a stand-alone with your own model, loading multi-temporal multi-modal data from SEN12MS-CR-TS is as simple as + +``` python +import torch +from data.dataLoader import SEN12MSCRTS +dir_SEN12MSCRTS = '/path/to/your/SEN12MSCRTS' +sen12mscrts = SEN12MSCRTS(dir_SEN12MSCRTS, split='all', region='all', n_input_samples=3) +dataloader = torch.utils.data.DataLoader(sen12mscrts) + +for pdx, samples in enumerate(dataloader): print(samples['input'].keys()) +``` + +and, likewise, if you wish to (pre-)train on the mono-temporal multi-modal SEN12MS-CR dataset: + +``` python +import torch +from data.dataLoader import SEN12MSCR +dir_SEN12MSCR = '/path/to/your/SEN12MSCR' +sen12mscr = SEN12MSCR(dir_SEN12MSCR, split='all', region='all') +dataloader = torch.utils.data.DataLoader(sen12mscr) + +for pdx, samples in enumerate(dataloader): print(samples['input'].keys()) +``` + +Note that computing cloud masks on the fly, depending on the choice of cloud detection, may slow down data loading. For greater efficiency, files of pre-computed cloud coverage statistics can be +downloaded [here](https://u.pcloud.link/publink/show?code=kZXdbk0ZaAHNV2a5ofbB9UW4xCyCT0YFYAFk) or pre-computed via `./util/pre_compute_data_samples.py`, and then loaded with the `--precomputed /path/to/files/` flag. + +### Basic Commands +You can train a new model via +```bash +cd ./UnCRtainTS/model +python train_reconstruct.py --experiment_name my_first_experiment --root1 path/to/SEN12MSCRtrain --root2 path/to/SEN12MSCRtest --root3 path/to/SEN12MSCR --model uncrtaints --input_t 3 --region all --epochs 20 --lr 0.001 --batch_size 4 --gamma 1.0 --scale_by 10.0 --trained_checkp "" --loss MGNLL --covmode diag --var_nonLinearity softplus --display_step 10 --use_sar --block_type mbconv --n_head 16 --device cuda --res_dir ./results --rdm_seed 1 +``` +and you can test a (pre-)trained model via +```bash +python test_reconstruct.py --experiment_name my_first_experiment -root1 path/to/SEN12MSCRtrain --root2 path/to/SEN12MSCRtest --root3 path/to/SEN12MSCR --input_t 3 --region all --export_every 1 --res_dir ./inference --weight_folder ./results +``` + +For a list and description of all flags, please see the parser file `./model/parse_args.py`. To perform inference with pre-trained models, [here](https://u.pcloud.link/publink/show?code=kZsdbk0Z5Y2Y2UEm48XLwOvwSVlL8R2L3daV)'s where you can find the checkpoints. Every checkpoint is accompanied by a json file, documenting the flags set during training and expected to reproduce the model's behavior at test time. If pointing towards the exported configurations upon call, the correct settings get loaded automatically in the test script. Finally, following the exporting of model predictions via `test_reconstruct.py`, multiple models' outputs can be ensembled via `ensemble_reconstruct.py`, to obtain estimates of epistemic uncertainty. + +--- + + +## References + +If you use this code, our models or data set for your research, please cite [this](https://openaccess.thecvf.com/content/CVPR2023W/EarthVision/papers/Ebel_UnCRtainTS_Uncertainty_Quantification_for_Cloud_Removal_in_Optical_Satellite_Time_CVPRW_2023_paper.pdf) publication: +```bibtex +@inproceedings{UnCRtainTS, + title = {{UnCRtainTS: Uncertainty Quantification for Cloud Removal in Optical Satellite Time Series}}, + author = {Ebel, Patrick and Garnot, Vivien Sainte Fare and Schmitt, Michael and Wegner, Jan and Zhu, Xiao Xiang}, + booktitle = {Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition Workshops}, + year = {2023}, + organization = {IEEE}, + url = {"https://openaccess.thecvf.com/content/CVPR2023W/EarthVision/papers/Ebel_UnCRtainTS_Uncertainty_Quantification_for_Cloud_Removal_in_Optical_Satellite_Time_CVPRW_2023_paper.pdf"} +} +``` +You may also be interested in our related works, which you can discover on the accompanying [cloud removal projects website](https://patrickTUM.github.io/cloud_removal/). + + + +## Credits + +This code was originally based on the [UTAE](https://github.com/VSainteuf/utae-paps) and the [SEN12MS-CR-TS](https://github.com/PatrickTUM/SEN12MS-CR-TS) repositories. Thanks for making your code publicly available! We hope this repository will equally contribute to the development of future exciting work. diff --git a/UnCRtainTS/architecture.png b/UnCRtainTS/architecture.png new file mode 100644 index 0000000000000000000000000000000000000000..5b6ffe4a030e6cf338de8a233429972ffb7d3b8f Binary files /dev/null and b/UnCRtainTS/architecture.png differ diff --git a/UnCRtainTS/data/__init__.py b/UnCRtainTS/data/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/UnCRtainTS/data/dataLoader.py b/UnCRtainTS/data/dataLoader.py new file mode 100644 index 0000000000000000000000000000000000000000..d54d31e63cd504e1ecc5f9fc22e59044d765d409 --- /dev/null +++ b/UnCRtainTS/data/dataLoader.py @@ -0,0 +1,633 @@ +import os +import glob +import warnings +import numpy as np +from natsort import natsorted + +from datetime import datetime + +to_date = lambda string: datetime.strptime(string, "%Y-%m-%d") +S1_LAUNCH = to_date("2014-04-03") + +# s2cloudless: see https://github.com/sentinel-hub/sentinel2-cloud-detector +from s2cloudless import S2PixelCloudDetector + +import rasterio +from rasterio.merge import merge +from scipy.ndimage import gaussian_filter +from torch.utils.data import Dataset +# import sys + +# sys.path.append(".") +from util.detect_cloudshadow import get_cloud_mask, get_shadow_mask + + +# utility functions used in the dataloaders of SEN12MS-CR and SEN12MS-CR-TS +def read_tif(path_IMG): + tif = rasterio.open(path_IMG) + return tif + + +def read_img(tif): + return tif.read().astype(np.float32) + + +def rescale(img, oldMin, oldMax): + oldRange = oldMax - oldMin + img = (img - oldMin) / oldRange + return img + + +def process_MS(img, method): + if method == "default": + intensity_min, intensity_max = ( + 0, + 10000, + ) # define a reasonable range of MS intensities + img = np.clip( + img, intensity_min, intensity_max + ) # intensity clipping to a global unified MS intensity range + img = rescale( + img, intensity_min, intensity_max + ) # project to [0,1], preserve global intensities (across patches), gets mapped to [-1,+1] in wrapper + if method == "resnet": + intensity_min, intensity_max = ( + 0, + 10000, + ) # define a reasonable range of MS intensities + img = np.clip( + img, intensity_min, intensity_max + ) # intensity clipping to a global unified MS intensity range + img /= 2000 # project to [0,5], preserve global intensities (across patches) + img = np.nan_to_num(img) + return img + + +def process_SAR(img, method): + if method == "default": + dB_min, dB_max = -25, 0 # define a reasonable range of SAR dB + img = np.clip( + img, dB_min, dB_max + ) # intensity clipping to a global unified SAR dB range + img = rescale( + img, dB_min, dB_max + ) # project to [0,1], preserve global intensities (across patches), gets mapped to [-1,+1] in wrapper + if method == "resnet": + # project SAR to [0, 2] range + dB_min, dB_max = [-25.0, -32.5], [0, 0] + img = np.concatenate( + [ + ( + 2 + * (np.clip(img[0], dB_min[0], dB_max[0]) - dB_min[0]) + / (dB_max[0] - dB_min[0]) + )[None, ...], + ( + 2 + * (np.clip(img[1], dB_min[1], dB_max[1]) - dB_min[1]) + / (dB_max[1] - dB_min[1]) + )[None, ...], + ], + axis=0, + ) + img = np.nan_to_num(img) + return img + + +def get_cloud_cloudshadow_mask(img, cloud_threshold=0.2): + cloud_mask = get_cloud_mask(img, cloud_threshold, binarize=True) + shadow_mask = get_shadow_mask(img) + + # encode clouds and shadows as segmentation masks + cloud_cloudshadow_mask = np.zeros_like(cloud_mask) + cloud_cloudshadow_mask[shadow_mask < 0] = -1 + cloud_cloudshadow_mask[cloud_mask > 0] = 1 + + # label clouds and shadows + cloud_cloudshadow_mask[cloud_cloudshadow_mask != 0] = 1 + return cloud_cloudshadow_mask + + +# recursively apply function to nested dictionary +def iterdict(dictionary, fct): + for k, v in dictionary.items(): + if isinstance(v, dict): + dictionary[k] = iterdict(v, fct) + else: + dictionary[k] = fct(v) + return dictionary + + +def get_cloud_map(img, detector, instance=None): + # get cloud masks + img = np.clip(img, 0, 10000) + mask = np.ones((img.shape[-1], img.shape[-1])) + # note: if your model may suffer from dark pixel artifacts, + # you may consider adjusting these filtering parameters + if not (img.mean() < 1e-5 and img.std() < 1e-5): + if detector == "cloud_cloudshadow_mask": + threshold = 0.2 # set to e.g. 0.2 or 0.4 + mask = get_cloud_cloudshadow_mask(img, threshold) + elif detector == "s2cloudless_map": + threshold = 0.5 + mask = instance.get_cloud_probability_maps( + np.moveaxis(img / 10000, 0, -1)[None, ...] + )[0, ...] + mask[mask < threshold] = 0 + mask = gaussian_filter(mask, sigma=2) + elif detector == "s2cloudless_mask": + mask = instance.get_cloud_masks(np.moveaxis(img / 10000, 0, -1)[None, ...])[ + 0, ... + ] + else: + mask = np.ones((img.shape[-1], img.shape[-1])) + warnings.warn(f"Method {detector} not yet implemented!") + else: + warnings.warn(f"Encountered a blank sample, defaulting to cloudy mask.") + return mask.astype(np.float32) + + +# function to fetch paired data, which may differ in modalities or dates +def get_pairedS1(patch_list, root_dir, mod=None, time=None): + paired_list = [] + for patch in patch_list: + seed, roi, modality, time_number, fname = patch.split("/") + time = time_number if time is None else time # unless overwriting, ... + mod = ( + modality if mod is None else mod + ) # keep the patch list's original time and modality + n_patch = fname.split("patch_")[-1].split(".tif")[0] + paired_dir = os.path.join(seed, roi, mod.upper(), str(time)) + candidates = os.path.join( + root_dir, + paired_dir, + f"{mod}_{seed}_{roi}_ImgNo_{time}_*_patch_{n_patch}.tif", + ) + paired_list.append( + os.path.join(paired_dir, os.path.basename(glob.glob(candidates)[0])) + ) + return paired_list + + + + + + +""" SEN12MSCR data loader class, inherits from torch.utils.data.Dataset + + IN: + root: str, path to your copy of the SEN12MS-CR-TS data set + split: str, in [all | train | val | test] + region: str, [all | africa | america | asiaEast | asiaWest | europa] + cloud_masks: str, type of cloud mask detector to run on optical data, in [] + sample_type: str, [generic | cloudy_cloudfree] + n_input_samples: int, number of input samples in time series + rescale_method: str, [default | resnet] + + OUT: + data_loader: SEN12MSCRTS instance, implements an iterator that can be traversed via __getitem__(pdx), + which returns the pdx-th dictionary of patch-samples (whose structure depends on sample_type) +""" + + +class SEN12MSCR(Dataset): + def __init__( + self, + root, + split="all", + region="all", + cloud_masks="s2cloudless_mask", + sample_type="pretrain", + rescale_method="default", + ): + self.root_dir = root # set root directory which contains all ROI + self.region = region # region according to which the ROI are selected + if self.region != "all": + raise NotImplementedError # TODO: currently only supporting 'all' + self.ROI = { + "ROIs1158": ["106"], + "ROIs1868": [ + "17", + "36", + "56", + "73", + "85", + "100", + "114", + "119", + "121", + "126", + "127", + "139", + "142", + "143", + ], + "ROIs1970": [ + "20", + "21", + "35", + "40", + "57", + "65", + "71", + "82", + "83", + "91", + "112", + "116", + "119", + "128", + "132", + "133", + "135", + "139", + "142", + "144", + "149", + ], + "ROIs2017": [ + "8", + "22", + "25", + "32", + "49", + "61", + "63", + "69", + "75", + "103", + "108", + "115", + "116", + "117", + "130", + "140", + "146", + ], + } + + # define splits conform with SEN12MS-CR-TS + self.splits = {} + self.splits["train"] = [ + "ROIs1970_fall_s1/s1_3", + "ROIs1970_fall_s1/s1_22", + "ROIs1970_fall_s1/s1_148", + "ROIs1970_fall_s1/s1_107", + "ROIs1970_fall_s1/s1_1", + "ROIs1970_fall_s1/s1_114", + "ROIs1970_fall_s1/s1_135", + "ROIs1970_fall_s1/s1_40", + "ROIs1970_fall_s1/s1_42", + "ROIs1970_fall_s1/s1_31", + "ROIs1970_fall_s1/s1_149", + "ROIs1970_fall_s1/s1_64", + "ROIs1970_fall_s1/s1_28", + "ROIs1970_fall_s1/s1_144", + "ROIs1970_fall_s1/s1_57", + "ROIs1970_fall_s1/s1_35", + "ROIs1970_fall_s1/s1_133", + "ROIs1970_fall_s1/s1_30", + "ROIs1970_fall_s1/s1_134", + "ROIs1970_fall_s1/s1_141", + "ROIs1970_fall_s1/s1_112", + "ROIs1970_fall_s1/s1_116", + "ROIs1970_fall_s1/s1_37", + "ROIs1970_fall_s1/s1_26", + "ROIs1970_fall_s1/s1_77", + "ROIs1970_fall_s1/s1_100", + "ROIs1970_fall_s1/s1_83", + "ROIs1970_fall_s1/s1_71", + "ROIs1970_fall_s1/s1_93", + "ROIs1970_fall_s1/s1_119", + "ROIs1970_fall_s1/s1_104", + "ROIs1970_fall_s1/s1_136", + "ROIs1970_fall_s1/s1_6", + "ROIs1970_fall_s1/s1_41", + "ROIs1970_fall_s1/s1_125", + "ROIs1970_fall_s1/s1_91", + "ROIs1970_fall_s1/s1_131", + "ROIs1970_fall_s1/s1_120", + "ROIs1970_fall_s1/s1_110", + "ROIs1970_fall_s1/s1_19", + "ROIs1970_fall_s1/s1_14", + "ROIs1970_fall_s1/s1_81", + "ROIs1970_fall_s1/s1_39", + "ROIs1970_fall_s1/s1_109", + "ROIs1970_fall_s1/s1_33", + "ROIs1970_fall_s1/s1_88", + "ROIs1970_fall_s1/s1_11", + "ROIs1970_fall_s1/s1_128", + "ROIs1970_fall_s1/s1_142", + "ROIs1970_fall_s1/s1_122", + "ROIs1970_fall_s1/s1_4", + "ROIs1970_fall_s1/s1_27", + "ROIs1970_fall_s1/s1_147", + "ROIs1970_fall_s1/s1_85", + "ROIs1970_fall_s1/s1_82", + "ROIs1970_fall_s1/s1_105", + "ROIs1158_spring_s1/s1_9", + "ROIs1158_spring_s1/s1_1", + "ROIs1158_spring_s1/s1_124", + "ROIs1158_spring_s1/s1_40", + "ROIs1158_spring_s1/s1_101", + "ROIs1158_spring_s1/s1_21", + "ROIs1158_spring_s1/s1_134", + "ROIs1158_spring_s1/s1_145", + "ROIs1158_spring_s1/s1_141", + "ROIs1158_spring_s1/s1_66", + "ROIs1158_spring_s1/s1_8", + "ROIs1158_spring_s1/s1_26", + "ROIs1158_spring_s1/s1_77", + "ROIs1158_spring_s1/s1_113", + "ROIs1158_spring_s1/s1_100", + "ROIs1158_spring_s1/s1_117", + "ROIs1158_spring_s1/s1_119", + "ROIs1158_spring_s1/s1_6", + "ROIs1158_spring_s1/s1_58", + "ROIs1158_spring_s1/s1_120", + "ROIs1158_spring_s1/s1_110", + "ROIs1158_spring_s1/s1_126", + "ROIs1158_spring_s1/s1_115", + "ROIs1158_spring_s1/s1_121", + "ROIs1158_spring_s1/s1_39", + "ROIs1158_spring_s1/s1_109", + "ROIs1158_spring_s1/s1_63", + "ROIs1158_spring_s1/s1_75", + "ROIs1158_spring_s1/s1_132", + "ROIs1158_spring_s1/s1_128", + "ROIs1158_spring_s1/s1_142", + "ROIs1158_spring_s1/s1_15", + "ROIs1158_spring_s1/s1_45", + "ROIs1158_spring_s1/s1_97", + "ROIs1158_spring_s1/s1_147", + "ROIs1868_summer_s1/s1_90", + "ROIs1868_summer_s1/s1_87", + "ROIs1868_summer_s1/s1_25", + "ROIs1868_summer_s1/s1_124", + "ROIs1868_summer_s1/s1_114", + "ROIs1868_summer_s1/s1_135", + "ROIs1868_summer_s1/s1_40", + "ROIs1868_summer_s1/s1_101", + "ROIs1868_summer_s1/s1_42", + "ROIs1868_summer_s1/s1_31", + "ROIs1868_summer_s1/s1_36", + "ROIs1868_summer_s1/s1_139", + "ROIs1868_summer_s1/s1_56", + "ROIs1868_summer_s1/s1_133", + "ROIs1868_summer_s1/s1_55", + "ROIs1868_summer_s1/s1_43", + "ROIs1868_summer_s1/s1_113", + "ROIs1868_summer_s1/s1_76", + "ROIs1868_summer_s1/s1_123", + "ROIs1868_summer_s1/s1_143", + "ROIs1868_summer_s1/s1_93", + "ROIs1868_summer_s1/s1_125", + "ROIs1868_summer_s1/s1_89", + "ROIs1868_summer_s1/s1_120", + "ROIs1868_summer_s1/s1_126", + "ROIs1868_summer_s1/s1_72", + "ROIs1868_summer_s1/s1_115", + "ROIs1868_summer_s1/s1_121", + "ROIs1868_summer_s1/s1_146", + "ROIs1868_summer_s1/s1_140", + "ROIs1868_summer_s1/s1_95", + "ROIs1868_summer_s1/s1_102", + "ROIs1868_summer_s1/s1_7", + "ROIs1868_summer_s1/s1_11", + "ROIs1868_summer_s1/s1_132", + "ROIs1868_summer_s1/s1_15", + "ROIs1868_summer_s1/s1_137", + "ROIs1868_summer_s1/s1_4", + "ROIs1868_summer_s1/s1_27", + "ROIs1868_summer_s1/s1_147", + "ROIs1868_summer_s1/s1_86", + "ROIs1868_summer_s1/s1_47", + "ROIs2017_winter_s1/s1_68", + "ROIs2017_winter_s1/s1_25", + "ROIs2017_winter_s1/s1_62", + "ROIs2017_winter_s1/s1_135", + "ROIs2017_winter_s1/s1_42", + "ROIs2017_winter_s1/s1_64", + "ROIs2017_winter_s1/s1_21", + "ROIs2017_winter_s1/s1_55", + "ROIs2017_winter_s1/s1_112", + "ROIs2017_winter_s1/s1_116", + "ROIs2017_winter_s1/s1_8", + "ROIs2017_winter_s1/s1_59", + "ROIs2017_winter_s1/s1_49", + "ROIs2017_winter_s1/s1_104", + "ROIs2017_winter_s1/s1_81", + "ROIs2017_winter_s1/s1_146", + "ROIs2017_winter_s1/s1_75", + "ROIs2017_winter_s1/s1_94", + "ROIs2017_winter_s1/s1_102", + "ROIs2017_winter_s1/s1_61", + "ROIs2017_winter_s1/s1_47", + "ROIs1868_summer_s1/s1_100", # note: this ROI is also used for testing in SEN12MS-CR-TS. If you wish to combine both datasets, please comment out this line + ] + self.splits["val"] = [ + "ROIs2017_winter_s1/s1_22", + "ROIs1868_summer_s1/s1_19", + "ROIs1970_fall_s1/s1_65", + "ROIs1158_spring_s1/s1_17", + "ROIs2017_winter_s1/s1_107", + "ROIs1868_summer_s1/s1_80", + "ROIs1868_summer_s1/s1_127", + "ROIs2017_winter_s1/s1_130", + "ROIs1868_summer_s1/s1_17", + "ROIs2017_winter_s1/s1_84", + ] + self.splits["test"] = [ + "ROIs1158_spring_s1/s1_106", + "ROIs1158_spring_s1/s1_123", + "ROIs1158_spring_s1/s1_140", + "ROIs1158_spring_s1/s1_31", + "ROIs1158_spring_s1/s1_44", + "ROIs1868_summer_s1/s1_119", + "ROIs1868_summer_s1/s1_73", + "ROIs1970_fall_s1/s1_139", + "ROIs2017_winter_s1/s1_108", + "ROIs2017_winter_s1/s1_63", + ] + + self.splits["all"] = ( + self.splits["train"] + self.splits["test"] + self.splits["val"] + ) + self.split = split + + assert split in [ + "all", + "train", + "val", + "test", + ], "Input dataset must be either assigned as all, train, test, or val!" + assert sample_type in ["pretrain"], "Input data must be pretrain!" + assert cloud_masks in [ + None, + "cloud_cloudshadow_mask", + "s2cloudless_map", + "s2cloudless_mask", + ], "Unknown cloud mask type!" + + self.modalities = ["S1", "S2"] + self.cloud_masks = cloud_masks # e.g. 'cloud_cloudshadow_mask', 's2cloudless_map', 's2cloudless_mask' + self.sample_type = sample_type # e.g. 'pretrain' + + self.time_points = range(1) + self.n_input_t = 1 # specifies the number of samples, if only part of the time series is used as an input + + if self.cloud_masks in ["s2cloudless_map", "s2cloudless_mask"]: + self.cloud_detector = S2PixelCloudDetector( + threshold=0.4, all_bands=True, average_over=4, dilation_size=2 + ) + else: + self.cloud_detector = None + + self.paths = self.get_paths() + self.n_samples = len(self.paths) + + # raise a warning if no data has been found + if not self.n_samples: + self.throw_warn() + + self.method = rescale_method + + # indexes all patches contained in the current data split + def get_paths( + self, + ): # assuming for the same ROI+num, the patch numbers are the same + print(f"\nProcessing paths for {self.split} split of region {self.region}") + + paths = [] + seeds_S1 = natsorted( + [s1dir for s1dir in os.listdir(self.root_dir) if "_s1" in s1dir] + ) + for seed in seeds_S1: + rois_S1 = natsorted(os.listdir(os.path.join(self.root_dir, seed))) + for roi in rois_S1: + roi_dir = os.path.join(self.root_dir, seed, roi) + paths_S1 = natsorted( + [os.path.join(roi_dir, s1patch) for s1patch in os.listdir(roi_dir)] + ) + paths_S2 = [ + patch.replace("/s1", "/s2").replace("_s1", "_s2") + for patch in paths_S1 + ] + paths_S2_cloudy = [ + patch.replace("/s1", "/s2_cloudy").replace("_s1", "_s2_cloudy") + for patch in paths_S1 + ] + + for pdx, _ in enumerate(paths_S1): + # omit patches that are potentially unpaired + if not all( + [ + os.path.isfile(paths_S1[pdx]), + os.path.isfile(paths_S2[pdx]), + os.path.isfile(paths_S2_cloudy[pdx]), + ] + ): + continue + # don't add patch if not belonging to the selected split + if not any( + [ + split_roi in paths_S1[pdx] + for split_roi in self.splits[self.split] + ] + ): + continue + sample = { + "S1": paths_S1[pdx], + "S2": paths_S2[pdx], + "S2_cloudy": paths_S2_cloudy[pdx], + } + paths.append(sample) + return paths + + def __getitem__(self, pdx): # get the triplet of patch with ID pdx + s1_tif = read_tif(self.paths[pdx]["S1"]) + s2_tif = read_tif(self.paths[pdx]["S2"]) + s2_cloudy_tif = read_tif(self.paths[pdx]["S2_cloudy"]) + coord = list(s2_tif.bounds) + s1 = process_SAR(read_img(s1_tif), self.method) + s2 = read_img(s2_tif) # note: pre-processing happens after cloud detection + s2_cloudy = read_img( + s2_cloudy_tif + ) # note: pre-processing happens after cloud detection + mask = ( + None + if not self.cloud_masks + else get_cloud_map(s2_cloudy, self.cloud_masks, self.cloud_detector) + ) + + sample = { + "input": { + "S1": s1, + "S2": process_MS(s2_cloudy, self.method), + "masks": mask, + "coverage": np.mean(mask), + "S1 path": os.path.join(self.root_dir, self.paths[pdx]["S1"]), + "S2 path": os.path.join(self.root_dir, self.paths[pdx]["S2_cloudy"]), + "coord": coord, + }, + "target": { + "S2": process_MS(s2, self.method), + "S2 path": os.path.join(self.root_dir, self.paths[pdx]["S2"]), + "coord": coord, + }, + } + return sample + + def throw_warn(self): + warnings.warn( + """No data samples found! Please use the following directory structure: + + path/to/your/SEN12MSCR/directory: + ├───ROIs1158_spring_s1 + | ├─s1_1 + | | |... + | | ├─ROIs1158_spring_s1_1_p407.tif + | | |... + | ... + ├───ROIs1158_spring_s2 + | ├─s2_1 + | | |... + | | ├─ROIs1158_spring_s2_1_p407.tif + | | |... + | ... + ├───ROIs1158_spring_s2_cloudy + | ├─s2_cloudy_1 + | | |... + | | ├─ROIs1158_spring_s2_cloudy_1_p407.tif + | | |... + | ... + ... + + Note: Please arrange the dataset in a format as e.g. provided by the script dl_data.sh. + """ + ) + + def __len__(self): + # length of generated list + return self.n_samples + + +if __name__ == "__main__": + dataset = SEN12MSCR( + root="data2/SEN12MSCR", + split="all", + region="all", + cloud_masks="s2cloudless_mask", + sample_type="pretrain", + rescale_method="default", + ) + for each in dataset: + print(f"{each['input']['S1'].shape}") + print(f"{each['input']['S2'].shape}") + print(f"{each['input']['masks'].shape}") + print(f"{each['target']['S2'].shape}") + # (2, 256, 256) + # (13, 256, 256) + # (256, 256) + # (13, 256, 256) + break diff --git a/UnCRtainTS/environment.yaml b/UnCRtainTS/environment.yaml new file mode 100644 index 0000000000000000000000000000000000000000..7ba31d958998709770a5c0d62b139d8c7a7a407e --- /dev/null +++ b/UnCRtainTS/environment.yaml @@ -0,0 +1,13 @@ +name: uncrtaints +channels: + - nvidia + - pytorch + - defaults +dependencies: + - nvidia::cudatoolkit=11.7 + - python + - pip=20.3 + - pytorch::pytorch=2.0.0 + - numpy + - pip: + diff --git a/UnCRtainTS/model/.gitignore b/UnCRtainTS/model/.gitignore new file mode 100644 index 0000000000000000000000000000000000000000..dc9b1e4ede5bbb6bb492bc59cccc36b99a847d8f --- /dev/null +++ b/UnCRtainTS/model/.gitignore @@ -0,0 +1,136 @@ +# Byte-compiled / optimized / DLL files +todo.txt +__pycache__/ +*.py[cod] +*$py.class +*.swp +# C extensions +*.so +.idea/ +# Distribution / packaging +.Python +build/ +develop-eggs/ +dist/ +downloads/ +eggs/ +.eggs/ +lib/ +lib64/ +parts/ +sdist/ +var/ +wheels/ +pip-wheel-metadata/ +share/python-wheels/ +*.egg-info/ +.installed.cfg +*.egg + + +.DS_Store +MANIFEST + +# PyInstaller +# Usually these files are written by a python script from a template +# before PyInstaller builds the exe, so as to inject date/other infos into it. +*.manifest +*.spec + +# Installer logs +pip-log.txt +pip-delete-this-directory.txt + +# Unit test / coverage reports +htmlcov/ +.tox/ +.nox/ +.coverage +.coverage.* +.cache +nosetests.xml +coverage.xml +*.cover +*.py,cover +.hypothesis/ +.pytest_cache/ + +# Translations +*.mo +*.pot + +# Django stuff: +*.log +local_settings.py +db.sqlite3 +db.sqlite3-journal + +# Flask stuff: +instance/ +.webassets-cache + +# Scrapy stuff: +.scrapy + +# Sphinx documentation +docs/_build/ + +# PyBuilder +target/ + +# Jupyter Notebook +.ipynb_checkpoints + +# IPython +profile_default/ +ipython_config.py + +# pyenv +.python-version + +# pipenv +# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. +# However, in case of collaboration, if having platform-specific dependencies or dependencies +# having no cross-platform support, pipenv may install dependencies that don't work, or not +# install all needed dependencies. +#Pipfile.lock + +# PEP 582; used by e.g. github.com/David-OConnor/pyflow +__pypackages__/ + +# Celery stuff +celerybeat-schedule +celerybeat.pid + +# SageMath parsed files +*.sage.py + +# Environments +.env +.venv +env/ +venv/ +ENV/ +env.bak/ +venv.bak/ + +# Spyder project settings +.spyderproject +.spyproject + +# Rope project settings +.ropeproject + +# mkdocs documentation +/site + +# mypy +.mypy_cache/ +.dmypy.json +dmypy.json + +# Pyre type checker +.pyre/ + +# ignore particular folders and directories +./util/precomputed \ No newline at end of file diff --git a/UnCRtainTS/model/checkpoint/diffcr_bs32_epoch17/model.pth.tar b/UnCRtainTS/model/checkpoint/diffcr_bs32_epoch17/model.pth.tar new file mode 100644 index 0000000000000000000000000000000000000000..5443067aa8d2326ea5ba78c7bebdf528b6220143 --- /dev/null +++ b/UnCRtainTS/model/checkpoint/diffcr_bs32_epoch17/model.pth.tar @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:733e42830f19c26569427dec26ecb63bf98a19a3d48d76df041d30290e3a28c6 +size 213833726 diff --git a/UnCRtainTS/model/checkpoint/monotemporalL2/conf.json b/UnCRtainTS/model/checkpoint/monotemporalL2/conf.json new file mode 100644 index 0000000000000000000000000000000000000000..e217740d15f9ca8a1cc1615f7edc83a9c9ad0a2b --- /dev/null +++ b/UnCRtainTS/model/checkpoint/monotemporalL2/conf.json @@ -0,0 +1,56 @@ +{ + "model": "uncrtaints", + "encoder_widths": [ + 128 + ], + "decoder_widths": [ + 128, + 128, + 128, + 128, + 128 + ], + "out_conv": [ + 13 + ], + "mean_nonLinearity": true, + "var_nonLinearity": "softplus", + "use_sar": true, + "agg_mode": "att_group", + "encoder_norm": "group", + "decoder_norm": "batch", + "n_head": 1, + "d_model": 256, + "use_v": false, + "positional_encoding": true, + "d_k": 4, + "res_dir": "./results", + "experiment_name": "monotemporalL2", + "device": "cuda", + "display_step": 10, + "batch_size": 4, + "lr": 0.001, + "gamma": 0.8, + "ref_date": "2014-04-03", + "pad_value": 0, + "padding_mode": "reflect", + "val_every": 1, + "val_after": 0, + "pretrain": true, + "input_t": 1, + "sample_type": "pretrain", + "vary_samples": true, + "min_cov": 0.0, + "max_cov": 1.0, + "region": "all", + "max_samples": 1000000000, + "input_size": 256, + "plot_every": -1, + "loss": "l2", + "covmode": "diag", + "scale_by": 10.0, + "separate_out": false, + "resume_from": false, + "epochs": 20, + "trained_checkp": "" +} diff --git a/UnCRtainTS/model/checkpoint/monotemporalL2/model.pth.tar b/UnCRtainTS/model/checkpoint/monotemporalL2/model.pth.tar new file mode 100644 index 0000000000000000000000000000000000000000..bf1caf160c7994f0872b017c945a48bdb43517d0 --- /dev/null +++ b/UnCRtainTS/model/checkpoint/monotemporalL2/model.pth.tar @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:b86c07efb75812e5f9564da3c6ad289be5f82d86f6aa837df37f48b01b11aabc +size 6825365 diff --git a/UnCRtainTS/model/ensemble_reconstruct.py b/UnCRtainTS/model/ensemble_reconstruct.py new file mode 100644 index 0000000000000000000000000000000000000000..98a9734a95fe03ae2a62ebbd48a08b36d207a265 --- /dev/null +++ b/UnCRtainTS/model/ensemble_reconstruct.py @@ -0,0 +1,180 @@ +""" + Python script to obtain Deep Ensemble predictions by collecting each instance's pre-computed predictions. + Each member's predictions are first meant to be pre-computed via test_reconstruct.py, with the outputs exported, + and read again in this script. Online ensembling is currently not implemented as this may exceed hardware constraints. + For every ensemble member, the path to its output directory has to be specified in the list 'ensemble_paths'. +""" + +import os +import sys +import torch +import numpy as np +from tqdm import tqdm +from natsort import natsorted + +dirname = os.path.dirname(os.path.abspath(__file__)) +sys.path.append(os.path.dirname(dirname)) + +from data.dataLoader import SEN12MSCR, SEN12MSCRTS +from src.learning.metrics import img_metrics, avg_img_metrics +from train_reconstruct import recursive_todevice, compute_uce_auce, export, plot_img, save_results + +epoch = 1 +root = '/home/data/' # path to directory containing dataset +mode = 'test' # split to evaluate on +in_time = 3 # length of input time series +region = 'all' # region of areas of interest +max_samples = 1e9 # maximum count of samples to consider +uncertainty = 'both' # e.g. 'aleatoric', 'epistemic', 'both' --- only matters if ensemble==True +ensemble = True # whether to compute ensemble mean and var or not +pixelwise = True # whether to summarize errors and variances for image-based AUCE and UCE or keep pixel-based statistics +export_path = None # where to export ensemble statistics, set to None if no writing to files is desired + +# define path to find the individual ensembe member's predictions in +ensemble_paths = [os.path.join(dirname, 'inference', f'diagonal_1/export/epoch_{epoch}/{mode}'), + os.path.join(dirname, 'inference', f'diagonal_2/export/epoch_{epoch}/{mode}'), + os.path.join(dirname, 'inference', f'diagonal_3/export/epoch_{epoch}/{mode}'), + os.path.join(dirname, 'inference', f'diagonal_4/export/epoch_{epoch}/{mode}'), + os.path.join(dirname, 'inference', f'diagonal_5/export/epoch_{epoch}/{mode}'), + ] + +n_ensemble = len(ensemble_paths) +print('Ensembling over model predictions:') +for instance in ensemble_paths: print(instance) + +if export_path: + plot_dir = os.path.join(export_path, 'plots', f'epoch_{epoch}', f'{mode}') + export_dir = os.path.join(export_path, 'export', f'epoch_{epoch}', f'{mode}') + + +def prepare_data_multi(batch, device, batch_size=1, use_sar=True): + in_S2 = recursive_todevice(torch.tensor(batch['input']['S2']), device) + in_S2_td = recursive_todevice(torch.tensor(batch['input']['S2 TD']), device) + if batch_size>1: in_S2_td = torch.stack((in_S2_td)).T + in_m = recursive_todevice(torch.tensor(batch['input']['masks']), device) + target_S2 = recursive_todevice(torch.tensor(batch['target']['S2']), device) + y = target_S2 + + if use_sar: + in_S1 = recursive_todevice(torch.tensor(batch['input']['S1']), device) + in_S1_td = recursive_todevice(torch.tensor(batch['input']['S1 TD']), device) + if batch_size>1: in_S1_td = torch.stack((in_S1_td)).T + x = torch.cat((torch.stack(in_S1,dim=1), torch.stack(in_S2,dim=1)),dim=2) + dates = torch.stack((torch.tensor(in_S1_td),torch.tensor(in_S2_td))).float().mean(dim=0).to(device) + else: + x = in_S2 # torch.stack(in_S2,dim=1) + dates = torch.tensor(in_S2_td).float().to(device) + + return x.unsqueeze(dim=0), y.unsqueeze(dim=0), in_m.unsqueeze(dim=0), dates + + +def main(): + + # list all predictions of the first ensemble member + dataPath = ensemble_paths[0] + samples = natsorted([os.path.join(dataPath, f) for f in os.listdir(dataPath) if (os.path.isfile(os.path.join(dataPath, f)) and "_pred.npy" in f)]) + + # collect sample-averaged uncertainties and errors + img_meter = avg_img_metrics() + vars_aleatoric = [] + errs, errs_se, errs_ae = [], [], [] + + import_data_path = os.path.join(os.getcwd(), 'util', 'precomputed', f'generic_{in_time}_{mode}_{region}_s2cloudless_mask.npy') + import_data_path = import_data_path if os.path.isfile(import_data_path) else None + dt_test = SEN12MSCRTS(os.path.join(root, 'SEN12MSCRTS'), split=mode, region=region, sample_type="cloudy_cloudfree" , n_input_samples=in_time, import_data_path=import_data_path) + if len(dt_test.paths) != len(samples): raise AssertionError + + # iterate over the ensemble member's mean predictions + for idx, sample_mean in enumerate(tqdm(samples)): + if idx >= max_samples: break # exceeded desired sample count + + # fetch target data and cloud masks of idx-th sample from + batch = dt_test.getsample(idx) # ... in order to compute metrics + x, y, in_m, _ = prepare_data_multi(batch, 'cuda', batch_size=1, use_sar=False) + + try: + mean, var = [], [] + for path in ensemble_paths: # for each ensemble member ... + # ... load the member's mean predictions and ... + mean.append(np.load(os.path.join(path, os.path.basename(sample_mean)))) + # ... load the member's covariance or var predictions + sample_var = sample_mean.replace('_pred', '_covar') + if not os.path.isfile(os.path.join(path, os.path.basename(sample_var))): + sample_var = sample_mean.replace('_pred', '_var') + var.append(np.load(os.path.join(path, os.path.basename(sample_var)))) + except: + # skip any sample for which not all members provide predictions + # (note: we also next'ed the dataloader's sample already) + print(f'Skipped sample {idx}, missing data.') + continue + mean, var = np.array(mean), np.array(var) + + # get the variances from the covariance matrix + if len(var.shape) > 4: # loaded covariance matrix + var = np.moveaxis(np.diagonal(var, axis1=1, axis2=2), -1, 1) + + # combine predictions + + if ensemble: + # get ensemble estimate and epistemic uncertainty, + # approximate 1 Gaussian by mixture parameter ensembling + mean_ensemble = 1/n_ensemble * np.sum(mean, axis=0) + + if uncertainty == 'aleatoric': + # average the members' aleatoric uncertainties + var_ensemble = 1/n_ensemble * np.sum(var, axis=0) + elif uncertainty == 'epistemic': + # compute average variance of ensemble predictions + var_ensemble = 1/n_ensemble * np.sum(mean**2, axis=0) - mean_ensemble**2 + elif uncertainty == 'both': + # combine both + var_ensemble = 1/n_ensemble * np.sum(var + mean**2, axis=0) - mean_ensemble**2 + else: raise NotImplementedError + else: mean_ensemble, var_ensemble = mean[0], var[0] + + mean_ensemble = torch.tensor(mean_ensemble).cuda() + var_ensemble = torch.tensor(var_ensemble).cuda() + + # compute test metrics on ensemble prediction + extended_metrics = img_metrics(y[0], mean_ensemble.unsqueeze(dim=0), + var=var_ensemble.unsqueeze(dim=0), + pixelwise=pixelwise) + img_meter.add(extended_metrics) # accumulate performances over the entire split + + if pixelwise: # collect variances and errors + vars_aleatoric.extend(extended_metrics['pixelwise var']) + errs.extend(extended_metrics['pixelwise error']) + errs_se.extend(extended_metrics['pixelwise se']) + errs_ae.extend(extended_metrics['pixelwise ae']) + else: + vars_aleatoric.append(extended_metrics['mean var']) + errs.append(extended_metrics['error']) + errs_se.append(extended_metrics['mean se']) + errs_ae.append(extended_metrics['mean ae']) + + if export_path: # plot and export ensemble predictions + plot_img(mean_ensemble.unsqueeze(dim=0), 'pred', plot_dir, file_id=idx) + plot_img(x[0], 'in', plot_dir, file_id=idx) + plot_img(var_ensemble.mean(dim=0, keepdims=True).expand(3, *var_ensemble.shape[1:]).unsqueeze(dim=0), 'var', plot_dir, file_id=idx) + export(mean_ensemble[None], 'pred', export_dir, file_id=idx) + export(var_ensemble[None], 'var', export_dir, file_id=idx) + + + # compute UCE and AUCE + uce_l2, auce_l2 = compute_uce_auce(vars_aleatoric, errs, len(vars_aleatoric), percent=5, l2=True, mode=mode, step=0) + + # no need for a running mean here + img_meter.value()['UCE SE'] = uce_l2.cpu().numpy().item() + img_meter.value()['AUCE SE'] = auce_l2.cpu().numpy().item() + + print(f'{mode} split image metrics: {img_meter.value()}') + if export_path: + np.save(os.path.join(export_path, f'pred_var_{uncertainty}.npy'), vars_aleatoric) + np.save(os.path.join(export_path, 'errors.npy'), errs) + save_results(img_meter.value(), export_path, split=mode) + print(f'Exported predictions to path {export_path}') + + +if __name__ == "__main__": + main() + exit() \ No newline at end of file diff --git a/UnCRtainTS/model/inference/diffcr_bs32_epoch17/conf.json b/UnCRtainTS/model/inference/diffcr_bs32_epoch17/conf.json new file mode 100644 index 0000000000000000000000000000000000000000..89ba8009baed583786fd65259f0f50f87c6833a5 --- /dev/null +++ b/UnCRtainTS/model/inference/diffcr_bs32_epoch17/conf.json @@ -0,0 +1,73 @@ +{ + "model": "uncrtaints", + "experiment_name": "diffcr_bs32_epoch17", + "res_dir": "./inference", + "plot_every": -1, + "export_every": 1, + "resume_at": -1, + "encoder_widths": [ + 128 + ], + "decoder_widths": [ + 128, + 128, + 128, + 128, + 128 + ], + "out_conv": [ + 13 + ], + "mean_nonLinearity": true, + "var_nonLinearity": "softplus", + "agg_mode": "att_group", + "encoder_norm": "group", + "decoder_norm": "batch", + "block_type": "mbconv", + "padding_mode": "reflect", + "pad_value": 0, + "n_head": 16, + "d_model": 256, + "positional_encoding": true, + "d_k": 4, + "low_res_size": 32, + "use_v": false, + "num_workers": 0, + "rdm_seed": 1, + "device": "cuda:6", + "display_step": 10, + "loss": "MGNLL", + "resume_from": false, + "unfreeze_after": 0, + "epochs": 20, + "batch_size": 32, + "chunk_size": null, + "lr": 0.01, + "gamma": 1.0, + "val_every": 1, + "val_after": 0, + "use_sar": true, + "pretrain": true, + "input_t": 1, + "ref_date": "2014-04-03", + "sample_type": "pretrain", + "vary_samples": true, + "min_cov": 0.0, + "max_cov": 1.0, + "root1": "/home/data/SEN12MSCRTS", + "root2": "/home/data/SEN12MSCRTS", + "root3": "data2/SEN12MSCR", + "precomputed": "/home/code/UnCRtainTS/util/precomputed", + "region": "all", + "max_samples_count": 1000000000, + "max_samples_frac": 1.0, + "profile": false, + "trained_checkp": "", + "covmode": "diag", + "scale_by": 1.0, + "separate_out": false, + "weight_folder": "checkpoint/", + "use_custom": false, + "load_config": "", + "pid": 2049339 +} \ No newline at end of file diff --git a/UnCRtainTS/model/inference/monotemporalL2/conf.json b/UnCRtainTS/model/inference/monotemporalL2/conf.json new file mode 100644 index 0000000000000000000000000000000000000000..a5744dffaca9bd8ad722bacf69c0810771c083ec --- /dev/null +++ b/UnCRtainTS/model/inference/monotemporalL2/conf.json @@ -0,0 +1,75 @@ +{ + "model": "uncrtaints", + "encoder_widths": [ + 128 + ], + "decoder_widths": [ + 128, + 128, + 128, + 128, + 128 + ], + "out_conv": [ + 13 + ], + "mean_nonLinearity": true, + "var_nonLinearity": "softplus", + "use_sar": true, + "agg_mode": "att_group", + "encoder_norm": "group", + "decoder_norm": "batch", + "n_head": 1, + "d_model": 256, + "use_v": false, + "positional_encoding": true, + "d_k": 4, + "experiment_name": "monotemporalL2", + "lr": 0.001, + "gamma": 0.8, + "ref_date": "2014-04-03", + "pad_value": 0, + "padding_mode": "reflect", + "val_every": 1, + "val_after": 0, + "pretrain": true, + "sample_type": "pretrain", + "vary_samples": true, + "max_samples": 1000000000, + "input_size": 256, + "loss": "l2", + "covmode": "diag", + "scale_by": 10.0, + "separate_out": false, + "resume_from": false, + "epochs": 20, + "res_dir": "./inference", + "plot_every": -1, + "export_every": 1, + "resume_at": -1, + "device": "cuda:6", + "display_step": 10, + "batch_size": 128, + "input_t": 1, + "min_cov": 0.0, + "max_cov": 1.0, + "root1": "/home/data/SEN12MSCRTS", + "root2": "/home/data/SEN12MSCRTS", + "root3": "data2/SEN12MSCR", + "region": "all", + "max_samples_count": 1000000000, + "trained_checkp": "", + "weight_folder": "checkpoint/", + "pid": 2973877, + "block_type": "mbconv", + "low_res_size": 32, + "num_workers": 0, + "rdm_seed": 1, + "unfreeze_after": 0, + "chunk_size": null, + "precomputed": "/home/code/UnCRtainTS/util/precomputed", + "max_samples_frac": 1.0, + "profile": false, + "use_custom": false, + "load_config": "" +} \ No newline at end of file diff --git a/UnCRtainTS/model/inference/monotemporalL2/test_metrics.json b/UnCRtainTS/model/inference/monotemporalL2/test_metrics.json new file mode 100644 index 0000000000000000000000000000000000000000..5b5c48bf161e3944767db8e06349c9da718a350b --- /dev/null +++ b/UnCRtainTS/model/inference/monotemporalL2/test_metrics.json @@ -0,0 +1,11 @@ +{ + "RMSE": 0.038980784788900186, + "MAE": 0.02744151706378001, + "PSNR": 28.900039257648842, + "SAM": 8.320397798952893, + "SSIM": 0.8797316785507024, + "error": NaN, + "mean se": NaN, + "mean ae": NaN, + "mean var": NaN +} \ No newline at end of file diff --git a/UnCRtainTS/model/parse_args.py b/UnCRtainTS/model/parse_args.py new file mode 100644 index 0000000000000000000000000000000000000000..6cb4c2bea2531427fefbce737e90c2785770588a --- /dev/null +++ b/UnCRtainTS/model/parse_args.py @@ -0,0 +1,95 @@ +import argparse + +S2_BANDS = 13 + +def create_parser(mode='train'): + parser = argparse.ArgumentParser() + # model parameters + parser.add_argument( + "--model", + default='uncrtaints', # e.g. 'unet', 'utae', 'uncrtaints', + type=str, + help="Type of architecture to use. Can be one of: (utae/unet3d/fpn/convlstm/convgru/uconvlstm/buconvlstm)", + ) + parser.add_argument("--experiment_name", default='my_first_experiment', help="Name of the current experiment",) + + # fast switching between default arguments, depending on train versus test mode + if mode=='train': + parser.add_argument("--res_dir", default="./results", help="Path to where the results are stored, e.g. ./results for training or ./inference for testing",) + parser.add_argument("--plot_every", default=-1, type=int, help="Interval (in items) of exporting plots at validation or test time. Set -1 to disable") + parser.add_argument("--export_every", default=-1, type=int, help="Interval (in items) of exporting data at validation or test time. Set -1 to disable") + parser.add_argument("--resume_at", default=0, type=int, help="Epoch to resume training from (may re-weight --lr in the optimizer) or epoch to load checkpoint from at test time") + elif mode=='test': + parser.add_argument("--res_dir", default="./inference", type=str, help="Path to directory where results are written.") + parser.add_argument("--plot_every", default=-1, type=int, help="Interval (in items) of exporting plots at validation or test time. Set -1 to disable") + parser.add_argument("--export_every", default=1, type=int, help="Interval (in items) of exporting data at validation or test time. Set -1 to disable") + parser.add_argument("--resume_at", default=-1, type=int, help="Epoch to load checkpoint from and run testing with (use -1 for best on validation split)") + + parser.add_argument("--encoder_widths", default="[128]", type=str, help="e.g. [64,64,64,128] for U-TAE or [128] for UnCRtainTS") + parser.add_argument("--decoder_widths", default="[128,128,128,128,128]", type=str, help="e.g. [64,64,64,128] for U-TAE or [128,128,128,128,128] for UnCRtainTS") + parser.add_argument("--out_conv", default=f"[{S2_BANDS}]", help="output CONV, note: if inserting another layer then consider treating normalizations separately") + parser.add_argument("--mean_nonLinearity", dest="mean_nonLinearity", action="store_false", help="whether to apply a sigmoidal output nonlinearity to the mean prediction") + parser.add_argument("--var_nonLinearity", default="softplus", type=str, help="how to squash the network's variance outputs [relu | softplus | elu ]") + parser.add_argument("--agg_mode", default="att_group", type=str, help="type of temporal aggregation in L-TAE module") + parser.add_argument("--encoder_norm", default="group", type=str, help="e.g. 'group' (when using many channels) or 'instance' (for few channels)") + parser.add_argument("--decoder_norm", default="batch", type=str, help="e.g. 'group' (when using many channels) or 'instance' (for few channels)") + parser.add_argument("--block_type", default="mbconv", type=str, help="type of CONV block to use [residual | mbconv]") + parser.add_argument("--padding_mode", default="reflect", type=str) + parser.add_argument("--pad_value", default=0, type=float) + + # attention-specific parameters + parser.add_argument("--n_head", default=16, type=int, help="default value of 16, 4 for debugging") + parser.add_argument("--d_model", default=256, type=int, help="layers in L-TAE, default value of 256") + parser.add_argument("--positional_encoding", dest="positional_encoding", action="store_false", help="whether to use positional encoding or not") + parser.add_argument("--d_k", default=4, type=int) + parser.add_argument("--low_res_size", default=32, type=int, help="resolution to downsample to") + parser.add_argument("--use_v", dest="use_v", action="store_true", help="whether to use values v or not") + + # set-up parameters + parser.add_argument("--num_workers", default=0, type=int, help="Number of data loading workers") + parser.add_argument("--rdm_seed", default=1, type=int, help="Random seed") + parser.add_argument("--device",default="cuda",type=str,help="Name of device to use for tensor computations (cuda/cpu)",) + parser.add_argument("--display_step", default=10, type=int, help="Interval in batches between display of training metrics",) + + # training parameters + parser.add_argument("--loss", default="MGNLL", type=str, help="Image reconstruction loss to utilize [l1|l2|GNLL|MGNLL].") + parser.add_argument("--resume_from", dest="resume_from", action="store_true", help="resume training acc. to JSON in --experiment_name and *.pth chckp in --trained_checkp") + parser.add_argument("--unfreeze_after", default=0, type=int, help="When to unfreeze ALL weights for training") + parser.add_argument("--epochs", default=20, type=int, help="Number of epochs to train") + parser.add_argument("--batch_size", default=4, type=int, help="Batch size") + parser.add_argument("--chunk_size", type=int, help="Size of vmap batches, this can be adjusted to accommodate for additional memory needs") + parser.add_argument("--lr", default=1e-2, type=float, help="Learning rate, e.g. 0.01") + parser.add_argument("--gamma", default=1.0, type=float, help="Learning rate decay parameter for scheduler") + parser.add_argument("--val_every", default=1, type=int, help="Interval in epochs between two validation steps.") + parser.add_argument("--val_after", default=0, type=int, help="Do validation only after that many epochs.") + + # flags specific to SEN12MS-CR and SEN12MS-CR-TS + parser.add_argument("--use_sar", dest="use_sar", action="store_true", help="whether to use SAR or not") + parser.add_argument("--pretrain", dest="pretrain", action="store_true", help="whether to perform pretraining on SEN12MS-CR or training on SEN12MS-CR-TS") + parser.add_argument("--input_t", default=3, type=int, help="number of input time points to sample, unet3d needs at least 4 time points") + parser.add_argument("--ref_date", default="2014-04-03", type=str, help="reference date for Sentinel observations") + parser.add_argument("--sample_type", default="cloudy_cloudfree", type=str, help="type of samples returned [cloudy_cloudfree | generic]") + parser.add_argument("--vary_samples", dest="vary_samples", action="store_false", help="whether to sample different time points across epochs or not") + parser.add_argument("--min_cov", default=0.0, type=float, help="The minimum cloud coverage to accept per input sample at train time. Gets overwritten by --vary_samples") + parser.add_argument("--max_cov", default=1.0, type=float, help="The maximum cloud coverage to accept per input sample at train time. Gets overwritten by --vary_samples") + parser.add_argument("--root1", default='/home/data/SEN12MSCRTS', type=str, help="path to your copy of SEN12MS-CR-TS") + parser.add_argument("--root2", default='/home/data/SEN12MSCRTS', type=str, help="path to your copy of SEN12MS-CR-TS validation & test splits") + parser.add_argument("--root3", default='/home/data/SEN12MSCR', type=str, help="path to your copy of SEN12MS-CR for pretraining") + parser.add_argument("--precomputed", default='/home/code/UnCRtainTS/util/precomputed', type=str, help="path to pre-computed cloud statistics") + parser.add_argument("--region", default="all", type=str, help="region to (sub-)sample ROI from [all|africa|america|asiaEast|asiaWest|europa]") + parser.add_argument("--max_samples_count", default=int(1e9), type=int, help="count of data (sub-)samples to take") + parser.add_argument("--max_samples_frac", default=1.0, type=float, help="fraction of data (sub-)samples to take") + parser.add_argument("--profile", dest="profile", action="store_true", help="whether to profile code or not") + parser.add_argument("--trained_checkp", default="", type=str, help="Path to loading a pre-trained network *.pth file, rather than initializing weights randomly") + + # flags specific to uncertainty modeling + parser.add_argument("--covmode", default='diag', type=str, help="covariance matrix type [uni|iso|diag].") + parser.add_argument("--scale_by", default=1.0, type=float, help="rescale data within model, e.g. to [0,10]") + parser.add_argument("--separate_out", dest="separate_out", action="store_true", help="whether to separately process mean and variance predictions or in a shared layer") + + # flags specific for testing + parser.add_argument("--weight_folder", type=str, default="./results", help="Path to the main folder containing the pre-trained weights") + parser.add_argument("--use_custom", dest="use_custom", action="store_true", help="whether to test on individually specified patches or not") + parser.add_argument("--load_config", default='', type=str, help="path of conf.json file to load") + + return parser \ No newline at end of file diff --git a/UnCRtainTS/model/results/diffcr_bs32_lr15e-4/conf.json b/UnCRtainTS/model/results/diffcr_bs32_lr15e-4/conf.json new file mode 100644 index 0000000000000000000000000000000000000000..66d5736a5d4630d3803bf9e13ff1b03a2d8890eb --- /dev/null +++ b/UnCRtainTS/model/results/diffcr_bs32_lr15e-4/conf.json @@ -0,0 +1,74 @@ +{ + "model": "uncrtaints", + "experiment_name": "diffcr_bs32_lr15e-4", + "res_dir": "./results", + "plot_every": -1, + "export_every": -1, + "resume_at": 0, + "encoder_widths": [ + 128 + ], + "decoder_widths": [ + 128, + 128, + 128, + 128, + 128 + ], + "out_conv": [ + 13 + ], + "mean_nonLinearity": false, + "var_nonLinearity": "softplus", + "agg_mode": "att_group", + "encoder_norm": "group", + "decoder_norm": "batch", + "block_type": "mbconv", + "padding_mode": "reflect", + "pad_value": 0.0, + "n_head": 1, + "d_model": 256, + "positional_encoding": false, + "d_k": 4, + "low_res_size": 32, + "use_v": false, + "num_workers": 16, + "rdm_seed": 1, + "device": "cuda:0", + "display_step": 10, + "loss": "l2", + "resume_from": false, + "unfreeze_after": 0, + "epochs": 100, + "batch_size": 32, + "chunk_size": null, + "lr": 0.0005, + "gamma": 0.8, + "val_every": 1, + "val_after": 0, + "use_sar": true, + "pretrain": true, + "input_t": 1, + "ref_date": "2014-04-03", + "sample_type": "pretrain", + "vary_samples": false, + "min_cov": 0.0, + "max_cov": 1.0, + "root1": "/home/data/SEN12MSCRTS", + "root2": "/home/data/SEN12MSCRTS", + "root3": "data2/SEN12MSCR", + "precomputed": "/home/code/UnCRtainTS/util/precomputed", + "region": "all", + "max_samples_count": 1000000000, + "max_samples_frac": 1.0, + "profile": false, + "trained_checkp": "", + "covmode": "diag", + "scale_by": 10.0, + "separate_out": false, + "weight_folder": "./results", + "use_custom": false, + "load_config": "", + "pid": 2877152, + "N_params": 19322381 +} \ No newline at end of file diff --git a/UnCRtainTS/model/results/diffcr_bs32_lr15e-4/model.pth.tar b/UnCRtainTS/model/results/diffcr_bs32_lr15e-4/model.pth.tar new file mode 100644 index 0000000000000000000000000000000000000000..63561e8b33a153ddf9c6a7a4d2543b93a143075a --- /dev/null +++ b/UnCRtainTS/model/results/diffcr_bs32_lr15e-4/model.pth.tar @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:35fd168203bc0ba6bf830851778d4958658be218c9b17e047592737dc68e49b2 +size 213825786 diff --git a/UnCRtainTS/model/results/diffcr_bs32_lr15e-4/model_epoch_11.pth.tar b/UnCRtainTS/model/results/diffcr_bs32_lr15e-4/model_epoch_11.pth.tar new file mode 100644 index 0000000000000000000000000000000000000000..67b163dcf23b7f5ba7bed303133a3cc42cfb4473 --- /dev/null +++ b/UnCRtainTS/model/results/diffcr_bs32_lr15e-4/model_epoch_11.pth.tar @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:1c180b52ca34e85f3f29c3296b179723975c0326f6f3fc5efb3ea885badfdb92 +size 213833726 diff --git a/UnCRtainTS/model/results/diffcr_bs32_lr15e-4/model_epoch_36.pth.tar b/UnCRtainTS/model/results/diffcr_bs32_lr15e-4/model_epoch_36.pth.tar new file mode 100644 index 0000000000000000000000000000000000000000..fcf3205fc99c56bd59aba2d93bbdf346a5f9a926 --- /dev/null +++ b/UnCRtainTS/model/results/diffcr_bs32_lr15e-4/model_epoch_36.pth.tar @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:c6b156aee172142a2e5a530f7cef87675b8871985a410aefd03ae11004260a58 +size 213833726 diff --git a/UnCRtainTS/model/results/diffcr_bs32_lr15e-4/test_epoch_10_metrics.json b/UnCRtainTS/model/results/diffcr_bs32_lr15e-4/test_epoch_10_metrics.json new file mode 100644 index 0000000000000000000000000000000000000000..0693c53620f3c74fe6d257a8ce1241f69efd3295 --- /dev/null +++ b/UnCRtainTS/model/results/diffcr_bs32_lr15e-4/test_epoch_10_metrics.json @@ -0,0 +1,11 @@ +{ + "RMSE": 0.029363970156402505, + "MAE": 0.020022216210124334, + "PSNR": 31.572881788434003, + "SAM": 5.883645729394377, + "SSIM": 0.8995029243551826, + "error": NaN, + "mean se": NaN, + "mean ae": NaN, + "mean var": NaN +} \ No newline at end of file diff --git a/UnCRtainTS/model/results/diffcr_bs32_lr15e-4/test_epoch_11_metrics.json b/UnCRtainTS/model/results/diffcr_bs32_lr15e-4/test_epoch_11_metrics.json new file mode 100644 index 0000000000000000000000000000000000000000..3b95e2ba61903adec486b9acce3f2d57e9334d06 --- /dev/null +++ b/UnCRtainTS/model/results/diffcr_bs32_lr15e-4/test_epoch_11_metrics.json @@ -0,0 +1,11 @@ +{ + "RMSE": 0.028865320518000174, + "MAE": 0.019421607403492344, + "PSNR": 31.768579506695946, + "SAM": 5.820518317864266, + "SSIM": 0.9015902346229956, + "error": NaN, + "mean se": NaN, + "mean ae": NaN, + "mean var": NaN +} \ No newline at end of file diff --git a/UnCRtainTS/model/results/diffcr_bs32_lr15e-4/test_epoch_12_metrics.json b/UnCRtainTS/model/results/diffcr_bs32_lr15e-4/test_epoch_12_metrics.json new file mode 100644 index 0000000000000000000000000000000000000000..b17ea26e5f5d6f96c86eadc2de89aacab97ea0d5 --- /dev/null +++ b/UnCRtainTS/model/results/diffcr_bs32_lr15e-4/test_epoch_12_metrics.json @@ -0,0 +1,11 @@ +{ + "RMSE": 0.028450740317855727, + "MAE": 0.019143247260043655, + "PSNR": 31.91159520461337, + "SAM": 5.6859444906808125, + "SSIM": 0.9034519373110392, + "error": NaN, + "mean se": NaN, + "mean ae": NaN, + "mean var": NaN +} \ No newline at end of file diff --git a/UnCRtainTS/model/results/diffcr_bs32_lr15e-4/test_epoch_13_metrics.json b/UnCRtainTS/model/results/diffcr_bs32_lr15e-4/test_epoch_13_metrics.json new file mode 100644 index 0000000000000000000000000000000000000000..ee3aaade57912b2cf630392ca479fa02e6e93887 --- /dev/null +++ b/UnCRtainTS/model/results/diffcr_bs32_lr15e-4/test_epoch_13_metrics.json @@ -0,0 +1,11 @@ +{ + "RMSE": 0.02858291424103222, + "MAE": 0.01927231159917463, + "PSNR": 31.870293692465435, + "SAM": 5.595422786736038, + "SSIM": 0.9039587730968367, + "error": NaN, + "mean se": NaN, + "mean ae": NaN, + "mean var": NaN +} \ No newline at end of file diff --git a/UnCRtainTS/model/results/diffcr_bs32_lr15e-4/test_epoch_14_metrics.json b/UnCRtainTS/model/results/diffcr_bs32_lr15e-4/test_epoch_14_metrics.json new file mode 100644 index 0000000000000000000000000000000000000000..410a9b44c3419011bdd33dbca964a47b96b19bdd --- /dev/null +++ b/UnCRtainTS/model/results/diffcr_bs32_lr15e-4/test_epoch_14_metrics.json @@ -0,0 +1,11 @@ +{ + "RMSE": 0.028029547319395425, + "MAE": 0.018867645016614896, + "PSNR": 32.0631246642833, + "SAM": 5.500893342840581, + "SSIM": 0.9053583295002058, + "error": NaN, + "mean se": NaN, + "mean ae": NaN, + "mean var": NaN +} \ No newline at end of file diff --git a/UnCRtainTS/model/results/diffcr_bs32_lr15e-4/test_epoch_15_metrics.json b/UnCRtainTS/model/results/diffcr_bs32_lr15e-4/test_epoch_15_metrics.json new file mode 100644 index 0000000000000000000000000000000000000000..d060f105d89af079fe5f70b13b3124045439084f --- /dev/null +++ b/UnCRtainTS/model/results/diffcr_bs32_lr15e-4/test_epoch_15_metrics.json @@ -0,0 +1,11 @@ +{ + "RMSE": 0.028002896943423252, + "MAE": 0.018806874698693236, + "PSNR": 32.124292399735644, + "SAM": 5.547215727374945, + "SSIM": 0.9059335617932202, + "error": NaN, + "mean se": NaN, + "mean ae": NaN, + "mean var": NaN +} \ No newline at end of file diff --git a/UnCRtainTS/model/results/diffcr_bs32_lr15e-4/test_epoch_16_metrics.json b/UnCRtainTS/model/results/diffcr_bs32_lr15e-4/test_epoch_16_metrics.json new file mode 100644 index 0000000000000000000000000000000000000000..8b627190f9fe1f7f68308db64dbaee148f5a6143 --- /dev/null +++ b/UnCRtainTS/model/results/diffcr_bs32_lr15e-4/test_epoch_16_metrics.json @@ -0,0 +1,11 @@ +{ + "RMSE": 0.027591812040673, + "MAE": 0.018454319715277293, + "PSNR": 32.246636826172086, + "SAM": 5.4618273344745045, + "SSIM": 0.9069541601278078, + "error": NaN, + "mean se": NaN, + "mean ae": NaN, + "mean var": NaN +} \ No newline at end of file diff --git a/UnCRtainTS/model/results/diffcr_bs32_lr15e-4/test_epoch_17_metrics.json b/UnCRtainTS/model/results/diffcr_bs32_lr15e-4/test_epoch_17_metrics.json new file mode 100644 index 0000000000000000000000000000000000000000..8a0789149723fe804ef1231938da80946a7e861e --- /dev/null +++ b/UnCRtainTS/model/results/diffcr_bs32_lr15e-4/test_epoch_17_metrics.json @@ -0,0 +1,11 @@ +{ + "RMSE": 0.027644014190348384, + "MAE": 0.018551536467928936, + "PSNR": 32.24370574226725, + "SAM": 5.461317134667159, + "SSIM": 0.9069266391564054, + "error": NaN, + "mean se": NaN, + "mean ae": NaN, + "mean var": NaN +} \ No newline at end of file diff --git a/UnCRtainTS/model/results/diffcr_bs32_lr15e-4/test_epoch_18_metrics.json b/UnCRtainTS/model/results/diffcr_bs32_lr15e-4/test_epoch_18_metrics.json new file mode 100644 index 0000000000000000000000000000000000000000..06ce4bf83e99b8389aa2624eb7c1d2cc126a2939 --- /dev/null +++ b/UnCRtainTS/model/results/diffcr_bs32_lr15e-4/test_epoch_18_metrics.json @@ -0,0 +1,11 @@ +{ + "RMSE": 0.027595730930196957, + "MAE": 0.018533841286411494, + "PSNR": 32.25086082007444, + "SAM": 5.458732208793296, + "SSIM": 0.9071483427338218, + "error": NaN, + "mean se": NaN, + "mean ae": NaN, + "mean var": NaN +} \ No newline at end of file diff --git a/UnCRtainTS/model/results/diffcr_bs32_lr15e-4/test_epoch_19_metrics.json b/UnCRtainTS/model/results/diffcr_bs32_lr15e-4/test_epoch_19_metrics.json new file mode 100644 index 0000000000000000000000000000000000000000..be3fe55fc60d94c6ca9df3e9ebd379af812e4204 --- /dev/null +++ b/UnCRtainTS/model/results/diffcr_bs32_lr15e-4/test_epoch_19_metrics.json @@ -0,0 +1,11 @@ +{ + "RMSE": 0.027390471058173046, + "MAE": 0.018348841181363623, + "PSNR": 32.342438679101114, + "SAM": 5.419562214702756, + "SSIM": 0.9077709371763801, + "error": NaN, + "mean se": NaN, + "mean ae": NaN, + "mean var": NaN +} \ No newline at end of file diff --git a/UnCRtainTS/model/results/diffcr_bs32_lr15e-4/test_epoch_1_metrics.json b/UnCRtainTS/model/results/diffcr_bs32_lr15e-4/test_epoch_1_metrics.json new file mode 100644 index 0000000000000000000000000000000000000000..1419d9ae76105d05f28821f8ac188a82ec94c6a4 --- /dev/null +++ b/UnCRtainTS/model/results/diffcr_bs32_lr15e-4/test_epoch_1_metrics.json @@ -0,0 +1,11 @@ +{ + "RMSE": 0.044388687314263, + "MAE": 0.03238296521911765, + "PSNR": 27.6217471128898, + "SAM": 10.4182609624991, + "SSIM": 0.763119293530315, + "error": NaN, + "mean se": NaN, + "mean ae": NaN, + "mean var": NaN +} \ No newline at end of file diff --git a/UnCRtainTS/model/results/diffcr_bs32_lr15e-4/test_epoch_20_metrics.json b/UnCRtainTS/model/results/diffcr_bs32_lr15e-4/test_epoch_20_metrics.json new file mode 100644 index 0000000000000000000000000000000000000000..1f2da820a793c4ac1f8a062f9486a8f30e8aae4a --- /dev/null +++ b/UnCRtainTS/model/results/diffcr_bs32_lr15e-4/test_epoch_20_metrics.json @@ -0,0 +1,11 @@ +{ + "RMSE": 0.02744433999974496, + "MAE": 0.018402347492774685, + "PSNR": 32.3317267860638, + "SAM": 5.421831781584368, + "SSIM": 0.9078204176338086, + "error": NaN, + "mean se": NaN, + "mean ae": NaN, + "mean var": NaN +} \ No newline at end of file diff --git a/UnCRtainTS/model/results/diffcr_bs32_lr15e-4/test_epoch_21_metrics.json b/UnCRtainTS/model/results/diffcr_bs32_lr15e-4/test_epoch_21_metrics.json new file mode 100644 index 0000000000000000000000000000000000000000..7db3b095f01fb8f32577b226e174ca7d0a48d784 --- /dev/null +++ b/UnCRtainTS/model/results/diffcr_bs32_lr15e-4/test_epoch_21_metrics.json @@ -0,0 +1,11 @@ +{ + "RMSE": 0.02734196908368276, + "MAE": 0.018271456095607545, + "PSNR": 32.37505691355452, + "SAM": 5.414982792168764, + "SSIM": 0.9081676423451513, + "error": NaN, + "mean se": NaN, + "mean ae": NaN, + "mean var": NaN +} \ No newline at end of file diff --git a/UnCRtainTS/model/results/diffcr_bs32_lr15e-4/test_epoch_22_metrics.json b/UnCRtainTS/model/results/diffcr_bs32_lr15e-4/test_epoch_22_metrics.json new file mode 100644 index 0000000000000000000000000000000000000000..ed598b3b4ca4ac7c5702e7008efe8396dd09b18d --- /dev/null +++ b/UnCRtainTS/model/results/diffcr_bs32_lr15e-4/test_epoch_22_metrics.json @@ -0,0 +1,11 @@ +{ + "RMSE": 0.027245160590102936, + "MAE": 0.018224454178698474, + "PSNR": 32.41090423394659, + "SAM": 5.398054870446574, + "SSIM": 0.9083186444307726, + "error": NaN, + "mean se": NaN, + "mean ae": NaN, + "mean var": NaN +} \ No newline at end of file diff --git a/UnCRtainTS/model/results/diffcr_bs32_lr15e-4/test_epoch_23_metrics.json b/UnCRtainTS/model/results/diffcr_bs32_lr15e-4/test_epoch_23_metrics.json new file mode 100644 index 0000000000000000000000000000000000000000..2ecf476d094f18ad2a28e2b777fd5b96262148fa --- /dev/null +++ b/UnCRtainTS/model/results/diffcr_bs32_lr15e-4/test_epoch_23_metrics.json @@ -0,0 +1,11 @@ +{ + "RMSE": 0.02729571228139796, + "MAE": 0.01828788889602147, + "PSNR": 32.395462209747514, + "SAM": 5.393570169404665, + "SSIM": 0.9081527106459147, + "error": NaN, + "mean se": NaN, + "mean ae": NaN, + "mean var": NaN +} \ No newline at end of file diff --git a/UnCRtainTS/model/results/diffcr_bs32_lr15e-4/test_epoch_24_metrics.json b/UnCRtainTS/model/results/diffcr_bs32_lr15e-4/test_epoch_24_metrics.json new file mode 100644 index 0000000000000000000000000000000000000000..562f71bc638fff09b7ae96b0c68ed50646868daa --- /dev/null +++ b/UnCRtainTS/model/results/diffcr_bs32_lr15e-4/test_epoch_24_metrics.json @@ -0,0 +1,11 @@ +{ + "RMSE": 0.027244927966896305, + "MAE": 0.018264471261733865, + "PSNR": 32.41291005914153, + "SAM": 5.392408405063824, + "SSIM": 0.908316025697424, + "error": NaN, + "mean se": NaN, + "mean ae": NaN, + "mean var": NaN +} \ No newline at end of file diff --git a/UnCRtainTS/model/results/diffcr_bs32_lr15e-4/test_epoch_25_metrics.json b/UnCRtainTS/model/results/diffcr_bs32_lr15e-4/test_epoch_25_metrics.json new file mode 100644 index 0000000000000000000000000000000000000000..c9d2b8b88b3215e80a0613e2143b60066d25e64c --- /dev/null +++ b/UnCRtainTS/model/results/diffcr_bs32_lr15e-4/test_epoch_25_metrics.json @@ -0,0 +1,11 @@ +{ + "RMSE": 0.027284221707286997, + "MAE": 0.018280692685820887, + "PSNR": 32.40145441845609, + "SAM": 5.39122454551732, + "SSIM": 0.9083838710484244, + "error": NaN, + "mean se": NaN, + "mean ae": NaN, + "mean var": NaN +} \ No newline at end of file diff --git a/UnCRtainTS/model/results/diffcr_bs32_lr15e-4/test_epoch_26_metrics.json b/UnCRtainTS/model/results/diffcr_bs32_lr15e-4/test_epoch_26_metrics.json new file mode 100644 index 0000000000000000000000000000000000000000..e5153997644ada4d2c13ddea06b8f3be9d9601f4 --- /dev/null +++ b/UnCRtainTS/model/results/diffcr_bs32_lr15e-4/test_epoch_26_metrics.json @@ -0,0 +1,11 @@ +{ + "RMSE": 0.02726639931698313, + "MAE": 0.01826652071185466, + "PSNR": 32.41051288700818, + "SAM": 5.3847173492672225, + "SSIM": 0.9083935374216957, + "error": NaN, + "mean se": NaN, + "mean ae": NaN, + "mean var": NaN +} \ No newline at end of file diff --git a/UnCRtainTS/model/results/diffcr_bs32_lr15e-4/test_epoch_27_metrics.json b/UnCRtainTS/model/results/diffcr_bs32_lr15e-4/test_epoch_27_metrics.json new file mode 100644 index 0000000000000000000000000000000000000000..c857ac784a4935d829d643277545a608aad13f40 --- /dev/null +++ b/UnCRtainTS/model/results/diffcr_bs32_lr15e-4/test_epoch_27_metrics.json @@ -0,0 +1,11 @@ +{ + "RMSE": 0.027239437090096214, + "MAE": 0.018245351712652874, + "PSNR": 32.42091501759734, + "SAM": 5.386129859399605, + "SSIM": 0.9085140088899692, + "error": NaN, + "mean se": NaN, + "mean ae": NaN, + "mean var": NaN +} \ No newline at end of file diff --git a/UnCRtainTS/model/results/diffcr_bs32_lr15e-4/test_epoch_28_metrics.json b/UnCRtainTS/model/results/diffcr_bs32_lr15e-4/test_epoch_28_metrics.json new file mode 100644 index 0000000000000000000000000000000000000000..fdc6893bb6a6f1e56c0a5045b1e9c5092845e615 --- /dev/null +++ b/UnCRtainTS/model/results/diffcr_bs32_lr15e-4/test_epoch_28_metrics.json @@ -0,0 +1,11 @@ +{ + "RMSE": 0.027215027071061892, + "MAE": 0.018221582465008372, + "PSNR": 32.42838631419994, + "SAM": 5.384088766455628, + "SSIM": 0.9085492332580157, + "error": NaN, + "mean se": NaN, + "mean ae": NaN, + "mean var": NaN +} \ No newline at end of file diff --git a/UnCRtainTS/model/results/diffcr_bs32_lr15e-4/test_epoch_29_metrics.json b/UnCRtainTS/model/results/diffcr_bs32_lr15e-4/test_epoch_29_metrics.json new file mode 100644 index 0000000000000000000000000000000000000000..0b40569f110db208b5c6e6c9571d8165a5ede373 --- /dev/null +++ b/UnCRtainTS/model/results/diffcr_bs32_lr15e-4/test_epoch_29_metrics.json @@ -0,0 +1,11 @@ +{ + "RMSE": 0.027242751470667303, + "MAE": 0.01824946738290542, + "PSNR": 32.41849800725127, + "SAM": 5.381040883963731, + "SSIM": 0.9085034823882493, + "error": NaN, + "mean se": NaN, + "mean ae": NaN, + "mean var": NaN +} \ No newline at end of file diff --git a/UnCRtainTS/model/results/diffcr_bs32_lr15e-4/test_epoch_2_metrics.json b/UnCRtainTS/model/results/diffcr_bs32_lr15e-4/test_epoch_2_metrics.json new file mode 100644 index 0000000000000000000000000000000000000000..2af6911886945d2609905fdf99c48b3341f6120b --- /dev/null +++ b/UnCRtainTS/model/results/diffcr_bs32_lr15e-4/test_epoch_2_metrics.json @@ -0,0 +1,11 @@ +{ + "RMSE": 0.04118385424327889, + "MAE": 0.029236674496597642, + "PSNR": 28.28133535765438, + "SAM": 8.635967488469356, + "SSIM": 0.8477815685430351, + "error": NaN, + "mean se": NaN, + "mean ae": NaN, + "mean var": NaN +} \ No newline at end of file diff --git a/UnCRtainTS/model/results/diffcr_bs32_lr15e-4/test_epoch_30_metrics.json b/UnCRtainTS/model/results/diffcr_bs32_lr15e-4/test_epoch_30_metrics.json new file mode 100644 index 0000000000000000000000000000000000000000..e221ed654d96801a0a1bce56bcb7bd3701ecc85b --- /dev/null +++ b/UnCRtainTS/model/results/diffcr_bs32_lr15e-4/test_epoch_30_metrics.json @@ -0,0 +1,11 @@ +{ + "RMSE": 0.027199031498741918, + "MAE": 0.01820931484350237, + "PSNR": 32.43442219682148, + "SAM": 5.381641721873335, + "SSIM": 0.9086828993640171, + "error": NaN, + "mean se": NaN, + "mean ae": NaN, + "mean var": NaN +} \ No newline at end of file diff --git a/UnCRtainTS/model/results/diffcr_bs32_lr15e-4/test_epoch_31_metrics.json b/UnCRtainTS/model/results/diffcr_bs32_lr15e-4/test_epoch_31_metrics.json new file mode 100644 index 0000000000000000000000000000000000000000..6b8d3dc325245def9a616bd8729c1a14fc51ab99 --- /dev/null +++ b/UnCRtainTS/model/results/diffcr_bs32_lr15e-4/test_epoch_31_metrics.json @@ -0,0 +1,11 @@ +{ + "RMSE": 0.027213316269354252, + "MAE": 0.01822304026870589, + "PSNR": 32.43005099008614, + "SAM": 5.380723569767104, + "SSIM": 0.9086281500793518, + "error": NaN, + "mean se": NaN, + "mean ae": NaN, + "mean var": NaN +} \ No newline at end of file diff --git a/UnCRtainTS/model/results/diffcr_bs32_lr15e-4/test_epoch_32_metrics.json b/UnCRtainTS/model/results/diffcr_bs32_lr15e-4/test_epoch_32_metrics.json new file mode 100644 index 0000000000000000000000000000000000000000..1188ff86107ce758af23e58d22eb6ca9f22c0695 --- /dev/null +++ b/UnCRtainTS/model/results/diffcr_bs32_lr15e-4/test_epoch_32_metrics.json @@ -0,0 +1,11 @@ +{ + "RMSE": 0.027218542108819188, + "MAE": 0.01823459587795115, + "PSNR": 32.429575243876165, + "SAM": 5.380705763062456, + "SSIM": 0.9086230123762922, + "error": NaN, + "mean se": NaN, + "mean ae": NaN, + "mean var": NaN +} \ No newline at end of file diff --git a/UnCRtainTS/model/results/diffcr_bs32_lr15e-4/test_epoch_33_metrics.json b/UnCRtainTS/model/results/diffcr_bs32_lr15e-4/test_epoch_33_metrics.json new file mode 100644 index 0000000000000000000000000000000000000000..b23b5712084c77dd5e6e0ddc74891b99c6adf754 --- /dev/null +++ b/UnCRtainTS/model/results/diffcr_bs32_lr15e-4/test_epoch_33_metrics.json @@ -0,0 +1,11 @@ +{ + "RMSE": 0.027201242979035374, + "MAE": 0.01821136096979529, + "PSNR": 32.4347129297312, + "SAM": 5.380752220457475, + "SSIM": 0.9086639188425286, + "error": NaN, + "mean se": NaN, + "mean ae": NaN, + "mean var": NaN +} \ No newline at end of file diff --git a/UnCRtainTS/model/results/diffcr_bs32_lr15e-4/test_epoch_34_metrics.json b/UnCRtainTS/model/results/diffcr_bs32_lr15e-4/test_epoch_34_metrics.json new file mode 100644 index 0000000000000000000000000000000000000000..d1f07873d5da2a8493d04fa30f92fb2203273c53 --- /dev/null +++ b/UnCRtainTS/model/results/diffcr_bs32_lr15e-4/test_epoch_34_metrics.json @@ -0,0 +1,11 @@ +{ + "RMSE": 0.02720954676425417, + "MAE": 0.01821668226351625, + "PSNR": 32.43351948486106, + "SAM": 5.378364858830772, + "SSIM": 0.9086556218518279, + "error": NaN, + "mean se": NaN, + "mean ae": NaN, + "mean var": NaN +} \ No newline at end of file diff --git a/UnCRtainTS/model/results/diffcr_bs32_lr15e-4/test_epoch_35_metrics.json b/UnCRtainTS/model/results/diffcr_bs32_lr15e-4/test_epoch_35_metrics.json new file mode 100644 index 0000000000000000000000000000000000000000..6609703db67c6dde47995450ac21f4ee9f51a3ff --- /dev/null +++ b/UnCRtainTS/model/results/diffcr_bs32_lr15e-4/test_epoch_35_metrics.json @@ -0,0 +1,11 @@ +{ + "RMSE": 0.027196117778619137, + "MAE": 0.018208952186562502, + "PSNR": 32.43714808439791, + "SAM": 5.377404463051186, + "SSIM": 0.9086787752794829, + "error": NaN, + "mean se": NaN, + "mean ae": NaN, + "mean var": NaN +} \ No newline at end of file diff --git a/UnCRtainTS/model/results/diffcr_bs32_lr15e-4/test_epoch_36_metrics.json b/UnCRtainTS/model/results/diffcr_bs32_lr15e-4/test_epoch_36_metrics.json new file mode 100644 index 0000000000000000000000000000000000000000..6275fbf22dee9426c906bdca75f874f3d8908555 --- /dev/null +++ b/UnCRtainTS/model/results/diffcr_bs32_lr15e-4/test_epoch_36_metrics.json @@ -0,0 +1,11 @@ +{ + "RMSE": 0.027197365488975264, + "MAE": 0.018207441006122857, + "PSNR": 32.437294496043016, + "SAM": 5.378433104420995, + "SSIM": 0.908667840788918, + "error": NaN, + "mean se": NaN, + "mean ae": NaN, + "mean var": NaN +} \ No newline at end of file diff --git a/UnCRtainTS/model/results/diffcr_bs32_lr15e-4/test_epoch_3_metrics.json b/UnCRtainTS/model/results/diffcr_bs32_lr15e-4/test_epoch_3_metrics.json new file mode 100644 index 0000000000000000000000000000000000000000..8b4f5e1aa5137c0288fb16f2ac0bb618d6b2674f --- /dev/null +++ b/UnCRtainTS/model/results/diffcr_bs32_lr15e-4/test_epoch_3_metrics.json @@ -0,0 +1,11 @@ +{ + "RMSE": 0.03848631408196211, + "MAE": 0.02771346952286548, + "PSNR": 28.855159436197052, + "SAM": 7.957361651426766, + "SSIM": 0.8585438828465877, + "error": NaN, + "mean se": NaN, + "mean ae": NaN, + "mean var": NaN +} \ No newline at end of file diff --git a/UnCRtainTS/model/results/diffcr_bs32_lr15e-4/test_epoch_4_metrics.json b/UnCRtainTS/model/results/diffcr_bs32_lr15e-4/test_epoch_4_metrics.json new file mode 100644 index 0000000000000000000000000000000000000000..bbc21bec5778ab577a2d51710d0bfae4bed33f02 --- /dev/null +++ b/UnCRtainTS/model/results/diffcr_bs32_lr15e-4/test_epoch_4_metrics.json @@ -0,0 +1,11 @@ +{ + "RMSE": 0.035135230792874714, + "MAE": 0.024759766860638684, + "PSNR": 29.80269560963009, + "SAM": 7.235543962037466, + "SSIM": 0.8712200425986855, + "error": NaN, + "mean se": NaN, + "mean ae": NaN, + "mean var": NaN +} \ No newline at end of file diff --git a/UnCRtainTS/model/results/diffcr_bs32_lr15e-4/test_epoch_5_metrics.json b/UnCRtainTS/model/results/diffcr_bs32_lr15e-4/test_epoch_5_metrics.json new file mode 100644 index 0000000000000000000000000000000000000000..4be61ccdd94370e54113aaf0879dc3e4875087ce --- /dev/null +++ b/UnCRtainTS/model/results/diffcr_bs32_lr15e-4/test_epoch_5_metrics.json @@ -0,0 +1,11 @@ +{ + "RMSE": 0.03312695752650094, + "MAE": 0.022758784721772722, + "PSNR": 30.4681653999379, + "SAM": 6.992262406113135, + "SSIM": 0.8825965048859602, + "error": NaN, + "mean se": NaN, + "mean ae": NaN, + "mean var": NaN +} \ No newline at end of file diff --git a/UnCRtainTS/model/results/diffcr_bs32_lr15e-4/test_epoch_6_metrics.json b/UnCRtainTS/model/results/diffcr_bs32_lr15e-4/test_epoch_6_metrics.json new file mode 100644 index 0000000000000000000000000000000000000000..7ed8b22f98f5a5d767bf6ea50d832b16bec8c651 --- /dev/null +++ b/UnCRtainTS/model/results/diffcr_bs32_lr15e-4/test_epoch_6_metrics.json @@ -0,0 +1,11 @@ +{ + "RMSE": 0.03995560772895256, + "MAE": 0.02838218824638219, + "PSNR": 28.511472895441493, + "SAM": 8.739711025814838, + "SSIM": 0.8421466158589092, + "error": NaN, + "mean se": NaN, + "mean ae": NaN, + "mean var": NaN +} \ No newline at end of file diff --git a/UnCRtainTS/model/results/diffcr_bs32_lr15e-4/test_epoch_7_metrics.json b/UnCRtainTS/model/results/diffcr_bs32_lr15e-4/test_epoch_7_metrics.json new file mode 100644 index 0000000000000000000000000000000000000000..b22cfd1fa5d522d5da19cb5dd732ca2482b52c5d --- /dev/null +++ b/UnCRtainTS/model/results/diffcr_bs32_lr15e-4/test_epoch_7_metrics.json @@ -0,0 +1,11 @@ +{ + "RMSE": 0.031056248151076646, + "MAE": 0.021148267562658724, + "PSNR": 31.0565988574695, + "SAM": 6.382153594549705, + "SSIM": 0.891711821026675, + "error": NaN, + "mean se": NaN, + "mean ae": NaN, + "mean var": NaN +} \ No newline at end of file diff --git a/UnCRtainTS/model/results/diffcr_bs32_lr15e-4/test_epoch_8_metrics.json b/UnCRtainTS/model/results/diffcr_bs32_lr15e-4/test_epoch_8_metrics.json new file mode 100644 index 0000000000000000000000000000000000000000..08995a6368f43abd7f52c14b7cd6d465f6135d3a --- /dev/null +++ b/UnCRtainTS/model/results/diffcr_bs32_lr15e-4/test_epoch_8_metrics.json @@ -0,0 +1,11 @@ +{ + "RMSE": 0.031472812372084345, + "MAE": 0.021889572141977235, + "PSNR": 30.92460130434839, + "SAM": 6.1838127149329845, + "SSIM": 0.8932244605785025, + "error": NaN, + "mean se": NaN, + "mean ae": NaN, + "mean var": NaN +} \ No newline at end of file diff --git a/UnCRtainTS/model/results/diffcr_bs32_lr15e-4/test_epoch_9_metrics.json b/UnCRtainTS/model/results/diffcr_bs32_lr15e-4/test_epoch_9_metrics.json new file mode 100644 index 0000000000000000000000000000000000000000..fe5e6145734ed505354182e2dbc509a1843d211d --- /dev/null +++ b/UnCRtainTS/model/results/diffcr_bs32_lr15e-4/test_epoch_9_metrics.json @@ -0,0 +1,11 @@ +{ + "RMSE": 0.029993507192348737, + "MAE": 0.020500481629842984, + "PSNR": 31.38203462532372, + "SAM": 5.94331032130549, + "SSIM": 0.8963639775338471, + "error": NaN, + "mean se": NaN, + "mean ae": NaN, + "mean var": NaN +} \ No newline at end of file diff --git a/UnCRtainTS/model/results/diffcr_bs32_lr15e-4/trainlog.json b/UnCRtainTS/model/results/diffcr_bs32_lr15e-4/trainlog.json new file mode 100644 index 0000000000000000000000000000000000000000..5ea05734e6547eaa56c3dd68d43369ddd7579ce8 --- /dev/null +++ b/UnCRtainTS/model/results/diffcr_bs32_lr15e-4/trainlog.json @@ -0,0 +1,218 @@ +{ + "1": { + "train_epoch_time": 6364.494176864624, + "train_loss": 0.3237062679705589, + "val_epoch_time": 309.8454918861389, + "val_loss": 0.22507791634392632 + }, + "2": { + "train_epoch_time": 10007.606169939041, + "train_loss": 0.18630791476406475, + "val_epoch_time": 334.43060421943665, + "val_loss": 0.19290863794957094 + }, + "3": { + "train_epoch_time": 7983.091668128967, + "train_loss": 0.14397588547575146, + "val_epoch_time": 317.11457443237305, + "val_loss": 0.16771913052872112 + }, + "4": { + "train_epoch_time": 7319.202767133713, + "train_loss": 0.12213725253918753, + "val_epoch_time": 298.7614150047302, + "val_loss": 0.142142680658382 + }, + "5": { + "train_epoch_time": 7380.555927991867, + "train_loss": 0.10246046889589004, + "val_epoch_time": 300.55992579460144, + "val_loss": 0.13023417059950504 + }, + "6": { + "train_epoch_time": 7354.242642402649, + "train_loss": 0.09047958616789056, + "val_epoch_time": 295.45529413223267, + "val_loss": 0.18040498171137423 + }, + "7": { + "train_epoch_time": 7174.017826557159, + "train_loss": 0.07991423370502114, + "val_epoch_time": 293.46724581718445, + "val_loss": 0.11674851320895106 + }, + "8": { + "train_epoch_time": 7174.240574836731, + "train_loss": 0.07147864294231862, + "val_epoch_time": 296.465868473053, + "val_loss": 0.11874196446134987 + }, + "9": { + "train_epoch_time": 7232.939740657806, + "train_loss": 0.06549475076196906, + "val_epoch_time": 299.91336607933044, + "val_loss": 0.10886607044621519 + }, + "10": { + "train_epoch_time": 7210.83275103569, + "train_loss": 0.06024148542104809, + "val_epoch_time": 301.700749874115, + "val_loss": 0.10466061386954695 + }, + "11": { + "train_epoch_time": 7007.304115772247, + "train_loss": 0.05637170380872225, + "val_epoch_time": 309.47249245643616, + "val_loss": 0.10233035268118751 + }, + "12": { + "train_epoch_time": 6062.465890169144, + "train_loss": 0.053290489858322104, + "val_epoch_time": 308.7875895500183, + "val_loss": 0.09965699104976798 + }, + "13": { + "train_epoch_time": 6976.069507598877, + "train_loss": 0.050853383546332044, + "val_epoch_time": 301.70971298217773, + "val_loss": 0.09977608145429541 + }, + "14": { + "train_epoch_time": 6991.846319437027, + "train_loss": 0.04896079883251077, + "val_epoch_time": 300.4676983356476, + "val_loss": 0.09670896215839417 + }, + "15": { + "train_epoch_time": 6902.513485908508, + "train_loss": 0.04742359011048691, + "val_epoch_time": 299.0248591899872, + "val_loss": 0.0977693937515259 + }, + "16": { + "train_epoch_time": 7109.026990413666, + "train_loss": 0.04616981182001833, + "val_epoch_time": 303.93678188323975, + "val_loss": 0.09447520863837915 + }, + "17": { + "train_epoch_time": 7376.631589651108, + "train_loss": 0.04517159000833257, + "val_epoch_time": 298.8597710132599, + "val_loss": 0.09530528654016339 + }, + "18": { + "train_epoch_time": 7285.304846763611, + "train_loss": 0.04438288844912876, + "val_epoch_time": 303.00490379333496, + "val_loss": 0.09506756897548191 + }, + "19": { + "train_epoch_time": 7787.80112695694, + "train_loss": 0.04375889115192281, + "val_epoch_time": 303.87041115760803, + "val_loss": 0.09397591768880845 + }, + "20": { + "train_epoch_time": 7260.875463724136, + "train_loss": 0.04322601202640268, + "val_epoch_time": 299.6984910964966, + "val_loss": 0.09435483056162053 + }, + "21": { + "train_epoch_time": 7451.247015953064, + "train_loss": 0.04281983784286288, + "val_epoch_time": 323.1254963874817, + "val_loss": 0.09372285795341498 + }, + "22": { + "train_epoch_time": 7981.169721841812, + "train_loss": 0.042475839341673184, + "val_epoch_time": 291.4349844455719, + "val_loss": 0.09311704784782562 + }, + "23": { + "train_epoch_time": 7594.218820810318, + "train_loss": 0.04219691905429461, + "val_epoch_time": 297.0300626754761, + "val_loss": 0.0935743749481041 + }, + "24": { + "train_epoch_time": 7240.4726185798645, + "train_loss": 0.04198853899248015, + "val_epoch_time": 309.6284897327423, + "val_loss": 0.09330161978765895 + }, + "25": { + "train_epoch_time": 7616.718925952911, + "train_loss": 0.041811597145059366, + "val_epoch_time": 301.26167273521423, + "val_loss": 0.09353288537577575 + }, + "26": { + "train_epoch_time": 6763.542355775833, + "train_loss": 0.04166472484089874, + "val_epoch_time": 315.5020205974579, + "val_loss": 0.09341287668718982 + }, + "27": { + "train_epoch_time": 6766.121025562286, + "train_loss": 0.041551160714151644, + "val_epoch_time": 324.6825852394104, + "val_loss": 0.09334027086222464 + }, + "28": { + "train_epoch_time": 6762.440242052078, + "train_loss": 0.0414679682428255, + "val_epoch_time": 319.4597473144531, + "val_loss": 0.0931452083955651 + }, + "29": { + "train_epoch_time": 6737.97718834877, + "train_loss": 0.04137746809627797, + "val_epoch_time": 310.9894905090332, + "val_loss": 0.09332587258882731 + }, + "30": { + "train_epoch_time": 6701.353940248489, + "train_loss": 0.04132045027270899, + "val_epoch_time": 318.9266474246979, + "val_loss": 0.09306036161356857 + }, + "31": { + "train_epoch_time": 6772.876837253571, + "train_loss": 0.041288001738687276, + "val_epoch_time": 313.6783757209778, + "val_loss": 0.09314583595238843 + }, + "32": { + "train_epoch_time": 6766.960909605026, + "train_loss": 0.04124333679278965, + "val_epoch_time": 313.1634178161621, + "val_loss": 0.09321284862348306 + }, + "33": { + "train_epoch_time": 6783.178449869156, + "train_loss": 0.04120735381566611, + "val_epoch_time": 318.76844573020935, + "val_loss": 0.09307621442430057 + }, + "34": { + "train_epoch_time": 6751.531671047211, + "train_loss": 0.04118845720904739, + "val_epoch_time": 312.26704120635986, + "val_loss": 0.09316893079910082 + }, + "35": { + "train_epoch_time": 6698.0369555950165, + "train_loss": 0.041186124865399326, + "val_epoch_time": 310.79613065719604, + "val_loss": 0.09306091244751022 + }, + "36": { + "train_epoch_time": 6830.260222196579, + "train_loss": 0.04114843877437321, + "val_epoch_time": 326.72767424583435, + "val_loss": 0.09309466800422443 + } +} \ No newline at end of file diff --git a/UnCRtainTS/model/src/backbones/base_model.py b/UnCRtainTS/model/src/backbones/base_model.py new file mode 100644 index 0000000000000000000000000000000000000000..6540efb6cdef1e3f93b6729b0f91cebbc0a63a1c --- /dev/null +++ b/UnCRtainTS/model/src/backbones/base_model.py @@ -0,0 +1,131 @@ +import torch +import torch.nn as nn + +from src import losses, model_utils +from fvcore.nn import FlopCountAnalysis +from fvcore.nn import flop_count_table + +S2_BANDS = 13 + +class BaseModel(nn.Module): + def __init__( + self, + config + ): + super(BaseModel, self).__init__() + self.config = config # store config + self.frozen = False # no parameters are frozen + self.len_epoch = 0 # steps of one epoch + + # temporarily rescale model inputs & outputs by constant factor, e.g. from [0,1] to [0,100], + # to deal with numerical imprecision issues closeby 0 magnitude (and their inverses) + # --- convert inputs, mean & variance predictions to original scale again after NLL loss is computed + # note: this may also require adjusting the range of output nonlinearities in the generator network, + # i.e. out_mean, out_var and diag_var + + # -------------- set input via set_input and call forward --------------- + # inputs self.real_A & self.real_B set in set_input by * self.scale_by + # ------------------------------ then scale ----------------------------- + # output self.fake_B will automatically get scaled by '' + # ------------------- then compute loss via get_loss_G ------------------ + # output self.netG.variance will automatically get scaled by * self.scale_by**2 + # ----------------------------- then rescale ---------------------------- + # inputs self.real_A & self.real_B set in set_input by * 1/self.scale_by + # output self.fake_B set in self.forward by * 1/self.scale_by + # output self.netG.variance set in get_loss_G by * 1/self.scale_by**2 + self.scale_by = config.scale_by # temporarily rescale model inputs by constant factor, e.g. from [0,1] to [0,100] + + # fetch generator + self.netG = model_utils.get_generator(self.config) + + # 1 criterion + self.criterion = losses.get_loss(self.config) + self.log_vars = None + + # 2 optimizer: for G + paramsG = [{'params': self.netG.parameters()}] + + self.optimizer_G = torch.optim.Adam(paramsG, lr=config.lr) + + # 2 scheduler: for G, note: stepping takes place at the end of epoch + self.scheduler_G = torch.optim.lr_scheduler.ExponentialLR(self.optimizer_G, gamma=self.config.gamma) + + self.real_A = None + self.fake_B = None + self.real_B = None + self.dates = None + self.masks = None + self.netG.variance = None + + def forward(self): + # forward through generator, note: for val/test splits, + # 'with torch.no_grad():' is declared in train script + self.fake_B = self.netG(self.real_A, batch_positions=self.dates) + if self.config.profile: + flopstats = FlopCountAnalysis(self.netG, (self.real_A, self.dates)) + # print(flop_count_table(flopstats)) + # TFLOPS: flopstats.total() *1e-12 + # MFLOPS: flopstats.total() *1e-6 + # compute MFLOPS per input sample + self.flops = (flopstats.total()*1e-6)/self.config.batch_size + print(f"MFLOP count: {self.flops}") + self.netG.variance = None # purge earlier variance prediction, re-compute via get_loss_G() + + def backward_G(self): + # calculate generator loss + self.get_loss_G() + self.loss_G.backward() + + + def get_loss_G(self): + + if hasattr(self.netG, 'vars_idx'): + self.loss_G, self.netG.variance = losses.calc_loss(self.criterion, self.config, self.fake_B[:, :, :self.netG.mean_idx, ...], self.real_B, var=self.fake_B[:, :, self.netG.mean_idx:self.netG.vars_idx, ...]) + else: # used with all other models + self.loss_G, self.netG.variance = losses.calc_loss(self.criterion, self.config, self.fake_B[:, :, :S2_BANDS, ...], self.real_B, var=self.fake_B[:, :, S2_BANDS:, ...]) + + def set_input(self, input): + self.real_A = self.scale_by * input['A'].to(self.config.device) + self.real_B = self.scale_by * input['B'].to(self.config.device) + self.dates = None if input['dates'] is None else input['dates'].to(self.config.device) + self.masks = input['masks'].to(self.config.device) + + + def reset_input(self): + self.real_A = None + self.real_B = None + self.dates = None + self.masks = None + del self.real_A + del self.real_B + del self.dates + del self.masks + + + def rescale(self): + # rescale target and mean predictions + if hasattr(self, 'real_A'): self.real_A = 1/self.scale_by * self.real_A + self.real_B = 1/self.scale_by * self.real_B + self.fake_B = 1/self.scale_by * self.fake_B[:,:,:S2_BANDS,...] + + # rescale (co)variances + if hasattr(self.netG, 'variance') and self.netG.variance is not None: + self.netG.variance = 1/self.scale_by**2 * self.netG.variance + + def optimize_parameters(self): + self.forward() + del self.real_A + + # update G + self.optimizer_G.zero_grad() + self.backward_G() + self.optimizer_G.step() + + # re-scale inputs, predicted means, predicted variances, etc + self.rescale() + # resetting inputs after optimization saves memory + self.reset_input() + + if self.netG.training: + self.fake_B = self.fake_B.cpu() + if self.netG.variance is not None: self.netG.variance = self.netG.variance.cpu() \ No newline at end of file diff --git a/UnCRtainTS/model/src/backbones/convgru.py b/UnCRtainTS/model/src/backbones/convgru.py new file mode 100644 index 0000000000000000000000000000000000000000..600864ec1d8ae18f980d868746a27ba8ca4cda6a --- /dev/null +++ b/UnCRtainTS/model/src/backbones/convgru.py @@ -0,0 +1,226 @@ +""" +Modified from https://github.com/TUM-LMF/MTLCC-pytorch/blob/master/src/models/convlstm/convlstm.py +authors: TUM-LMF +""" +import torch.nn as nn +from torch.autograd import Variable +import torch + + +class ConvGRUCell(nn.Module): + def __init__(self, input_size, input_dim, hidden_dim, kernel_size, bias): + """ + Initialize ConvLSTM cell. + + Parameters + ---------- + input_size: (int, int) + Height and width of input tensor as (height, width). + input_dim: int + Number of channels of input tensor. + hidden_dim: int + Number of channels of hidden state. + kernel_size: (int, int) + Size of the convolutional kernel. + bias: bool + Whether or not to add the bias. + """ + + super(ConvGRUCell, self).__init__() + + self.height, self.width = input_size + self.input_dim = input_dim + self.hidden_dim = hidden_dim + + self.kernel_size = kernel_size + self.padding = kernel_size[0] // 2, kernel_size[1] // 2 + self.bias = bias + + self.in_conv = nn.Conv2d( + in_channels=self.input_dim + self.hidden_dim, + out_channels=2 * self.hidden_dim, + kernel_size=self.kernel_size, + padding=self.padding, + bias=self.bias, + ) + self.out_conv = nn.Conv2d( + in_channels=self.input_dim + self.hidden_dim, + out_channels=self.hidden_dim, + kernel_size=self.kernel_size, + padding=self.padding, + bias=self.bias, + ) + + def forward(self, input_tensor, cur_state): + combined = torch.cat([input_tensor, cur_state], dim=1) + z, r = torch.sigmoid(self.in_conv(combined)).chunk(2, dim=1) + h = torch.tanh(self.out_conv(torch.cat([input_tensor, r * cur_state], dim=1))) + new_state = (1 - z) * cur_state + z * h + return new_state + + def init_hidden(self, batch_size, device): + return Variable( + torch.zeros(batch_size, self.hidden_dim, self.height, self.width) + ).to(device) + + +class ConvGRU(nn.Module): + def __init__( + self, + input_size, + input_dim, + hidden_dim, + kernel_size, + num_layers=1, + batch_first=True, + bias=True, + return_all_layers=False, + ): + super(ConvGRU, self).__init__() + + self._check_kernel_size_consistency(kernel_size) + + # Make sure that both `kernel_size` and `hidden_dim` are lists having len == num_layers + kernel_size = self._extend_for_multilayer(kernel_size, num_layers) + hidden_dim = self._extend_for_multilayer(hidden_dim, num_layers) + if not len(kernel_size) == len(hidden_dim) == num_layers: + raise ValueError("Inconsistent list length.") + + self.height, self.width = input_size + + self.input_dim = input_dim + self.hidden_dim = hidden_dim + self.kernel_size = kernel_size + self.num_layers = num_layers + self.batch_first = batch_first + self.bias = bias + self.return_all_layers = return_all_layers + + cell_list = [] + for i in range(0, self.num_layers): + cur_input_dim = self.input_dim if i == 0 else self.hidden_dim[i - 1] + + cell_list.append( + ConvGRUCell( + input_size=(self.height, self.width), + input_dim=cur_input_dim, + hidden_dim=self.hidden_dim[i], + kernel_size=self.kernel_size[i], + bias=self.bias, + ) + ) + + self.cell_list = nn.ModuleList(cell_list) + + def forward( + self, input_tensor, hidden_state=None, pad_mask=None, batch_positions=None + ): + """ + + Parameters + ---------- + input_tensor: todo + 5-D Tensor either of shape (t, b, c, h, w) or (b, t, c, h, w) + hidden_state: todo + None. todo implement stateful + pad_maks (b , t) + Returns + ------- + last_state_list, layer_output + """ + if not self.batch_first: + # (t, b, c, h, w) -> (b, t, c, h, w) + input_tensor.permute(1, 0, 2, 3, 4) + + # Implement stateful ConvLSTM + if hidden_state is not None: + raise NotImplementedError() + else: + hidden_state = self._init_hidden( + batch_size=input_tensor.size(0), device=input_tensor.device + ) + + layer_output_list = [] + last_state_list = [] + + seq_len = input_tensor.size(1) + cur_layer_input = input_tensor + + for layer_idx in range(self.num_layers): + + h = hidden_state[layer_idx] + output_inner = [] + for t in range(seq_len): + h = self.cell_list[layer_idx]( + input_tensor=cur_layer_input[:, t, :, :, :], cur_state=h + ) + output_inner.append(h) + + layer_output = torch.stack(output_inner, dim=1) + if pad_mask is not None: + last_positions = (~pad_mask).sum(dim=1) - 1 + layer_output = layer_output[:, last_positions, :, :, :] + + cur_layer_input = layer_output + + layer_output_list.append(layer_output) + last_state_list.append(h) + + if not self.return_all_layers: + layer_output_list = layer_output_list[-1] + last_state_list = last_state_list[-1] + + return layer_output_list, last_state_list + + def _init_hidden(self, batch_size, device): + init_states = [] + for i in range(self.num_layers): + init_states.append(self.cell_list[i].init_hidden(batch_size, device)) + return init_states + + @staticmethod + def _check_kernel_size_consistency(kernel_size): + if not ( + isinstance(kernel_size, tuple) + or ( + isinstance(kernel_size, list) + and all([isinstance(elem, tuple) for elem in kernel_size]) + ) + ): + raise ValueError("`kernel_size` must be tuple or list of tuples") + + @staticmethod + def _extend_for_multilayer(param, num_layers): + if not isinstance(param, list): + param = [param] * num_layers + return param + + +class ConvGRU_Seg(nn.Module): + def __init__( + self, num_classes, input_size, input_dim, hidden_dim, kernel_size, pad_value=0 + ): + super(ConvGRU_Seg, self).__init__() + self.convgru_encoder = ConvGRU( + input_dim=input_dim, + input_size=input_size, + hidden_dim=hidden_dim, + kernel_size=kernel_size, + return_all_layers=False, + ) + self.classification_layer = nn.Conv2d( + in_channels=hidden_dim, + out_channels=num_classes, + kernel_size=kernel_size, + padding=1, + ) + self.pad_value = pad_value + + def forward(self, input, batch_positions=None): + pad_mask = ( + (input == self.pad_value).all(dim=-1).all(dim=-1).all(dim=-1) + ) # BxT pad mask + pad_mask = pad_mask if pad_mask.any() else None + _, out = self.convgru_encoder(input, pad_mask=pad_mask) + out = self.classification_layer(out) + return out diff --git a/UnCRtainTS/model/src/backbones/convlstm.py b/UnCRtainTS/model/src/backbones/convlstm.py new file mode 100644 index 0000000000000000000000000000000000000000..7c21d36bb4c33d5cdecf78376c8ad4e404bbf09e --- /dev/null +++ b/UnCRtainTS/model/src/backbones/convlstm.py @@ -0,0 +1,321 @@ +""" +Taken from https://github.com/TUM-LMF/MTLCC-pytorch/blob/master/src/models/convlstm/convlstm.py +authors: TUM-LMF +""" +import torch.nn as nn +from torch.autograd import Variable +import torch + + +class ConvLSTMCell(nn.Module): + def __init__(self, input_size, input_dim, hidden_dim, kernel_size, bias): + """ + Initialize ConvLSTM cell. + + Parameters + ---------- + input_size: (int, int) + Height and width of input tensor as (height, width). + input_dim: int + Number of channels of input tensor. + hidden_dim: int + Number of channels of hidden state. + kernel_size: (int, int) + Size of the convolutional kernel. + bias: bool + Whether or not to add the bias. + """ + + super(ConvLSTMCell, self).__init__() + + self.height, self.width = input_size + self.input_dim = input_dim + self.hidden_dim = hidden_dim + + self.kernel_size = kernel_size + self.padding = kernel_size[0] // 2, kernel_size[1] // 2 + self.bias = bias + + self.conv = nn.Conv2d( + in_channels=self.input_dim + self.hidden_dim, + out_channels=4 * self.hidden_dim, + kernel_size=self.kernel_size, + padding=self.padding, + bias=self.bias, + ) + + def forward(self, input_tensor, cur_state): + h_cur, c_cur = cur_state + + combined = torch.cat( + [input_tensor, h_cur], dim=1 + ) # concatenate along channel axis + + combined_conv = self.conv(combined) + cc_i, cc_f, cc_o, cc_g = torch.split(combined_conv, self.hidden_dim, dim=1) + i = torch.sigmoid(cc_i) + f = torch.sigmoid(cc_f) + o = torch.sigmoid(cc_o) + g = torch.tanh(cc_g) + + c_next = f * c_cur + i * g + h_next = o * torch.tanh(c_next) + + return h_next, c_next + + def init_hidden(self, batch_size, device): + return ( + Variable( + torch.zeros(batch_size, self.hidden_dim, self.height, self.width) + ).to(device), + Variable( + torch.zeros(batch_size, self.hidden_dim, self.height, self.width) + ).to(device), + ) + + +class ConvLSTM(nn.Module): + def __init__( + self, + input_size, + input_dim, + hidden_dim, + kernel_size, + num_layers=1, + batch_first=True, + bias=True, + return_all_layers=False, + ): + super(ConvLSTM, self).__init__() + + self._check_kernel_size_consistency(kernel_size) + + # Make sure that both `kernel_size` and `hidden_dim` are lists having len == num_layers + kernel_size = self._extend_for_multilayer(kernel_size, num_layers) + hidden_dim = self._extend_for_multilayer(hidden_dim, num_layers) + if not len(kernel_size) == len(hidden_dim) == num_layers: + raise ValueError("Inconsistent list length.") + + self.height, self.width = input_size + + self.input_dim = input_dim + self.hidden_dim = hidden_dim + self.kernel_size = kernel_size + self.num_layers = num_layers + self.batch_first = batch_first + self.bias = bias + self.return_all_layers = return_all_layers + + cell_list = [] + for i in range(0, self.num_layers): + cur_input_dim = self.input_dim if i == 0 else self.hidden_dim[i - 1] + + cell_list.append( + ConvLSTMCell( + input_size=(self.height, self.width), + input_dim=cur_input_dim, + hidden_dim=self.hidden_dim[i], + kernel_size=self.kernel_size[i], + bias=self.bias, + ) + ) + + self.cell_list = nn.ModuleList(cell_list) + + def forward(self, input_tensor, hidden_state=None, pad_mask=None): + """ + + Parameters + ---------- + input_tensor: todo + 5-D Tensor either of shape (t, b, c, h, w) or (b, t, c, h, w) + hidden_state: todo + None. todo implement stateful + pad_maks (b , t) + Returns + ------- + last_state_list, layer_output + """ + if not self.batch_first: + # (t, b, c, h, w) -> (b, t, c, h, w) + input_tensor.permute(1, 0, 2, 3, 4) + + # Implement stateful ConvLSTM + if hidden_state is not None: + raise NotImplementedError() + else: + hidden_state = self._init_hidden( + batch_size=input_tensor.size(0), device=input_tensor.device + ) + + layer_output_list = [] + last_state_list = [] + + seq_len = input_tensor.size(1) + cur_layer_input = input_tensor + + for layer_idx in range(self.num_layers): + + h, c = hidden_state[layer_idx] + output_inner = [] + for t in range(seq_len): + h, c = self.cell_list[layer_idx]( + input_tensor=cur_layer_input[:, t, :, :, :], cur_state=[h, c] + ) + output_inner.append(h) + + layer_output = torch.stack(output_inner, dim=1) + if pad_mask is not None: + last_positions = (~pad_mask).sum(dim=1) - 1 + layer_output = layer_output[:, last_positions, :, :, :] + + cur_layer_input = layer_output + + layer_output_list.append(layer_output) + last_state_list.append([h, c]) + + if not self.return_all_layers: + layer_output_list = layer_output_list[-1:] + last_state_list = last_state_list[-1:] + + return layer_output_list, last_state_list + + def _init_hidden(self, batch_size, device): + init_states = [] + for i in range(self.num_layers): + init_states.append(self.cell_list[i].init_hidden(batch_size, device)) + return init_states + + @staticmethod + def _check_kernel_size_consistency(kernel_size): + if not ( + isinstance(kernel_size, tuple) + or ( + isinstance(kernel_size, list) + and all([isinstance(elem, tuple) for elem in kernel_size]) + ) + ): + raise ValueError("`kernel_size` must be tuple or list of tuples") + + @staticmethod + def _extend_for_multilayer(param, num_layers): + if not isinstance(param, list): + param = [param] * num_layers + return param + + +class ConvLSTM_Seg(nn.Module): + def __init__( + self, num_classes, input_size, input_dim, hidden_dim, kernel_size, pad_value=0 + ): + super(ConvLSTM_Seg, self).__init__() + self.convlstm_encoder = ConvLSTM( + input_dim=input_dim, + input_size=input_size, + hidden_dim=hidden_dim, + kernel_size=kernel_size, + return_all_layers=False, + ) + self.classification_layer = nn.Conv2d( + in_channels=hidden_dim, + out_channels=num_classes, + kernel_size=kernel_size, + padding=1, + ) + self.pad_value = pad_value + + def forward(self, input, batch_positions=None): + pad_mask = ( + (input == self.pad_value).all(dim=-1).all(dim=-1).all(dim=-1) + ) # BxT pad mask + pad_mask = pad_mask if pad_mask.any() else None + _, states = self.convlstm_encoder(input, pad_mask=pad_mask) + out = states[0][1] # take last cell state as embedding + out = self.classification_layer(out) + + return out + + +class BConvLSTM_Seg(nn.Module): + def __init__( + self, num_classes, input_size, input_dim, hidden_dim, kernel_size, pad_value=0 + ): + super(BConvLSTM_Seg, self).__init__() + self.convlstm_forward = ConvLSTM( + input_dim=input_dim, + input_size=input_size, + hidden_dim=hidden_dim, + kernel_size=kernel_size, + return_all_layers=False, + ) + self.convlstm_backward = ConvLSTM( + input_dim=input_dim, + input_size=input_size, + hidden_dim=hidden_dim, + kernel_size=kernel_size, + return_all_layers=False, + ) + self.classification_layer = nn.Conv2d( + in_channels=2 * hidden_dim, + out_channels=num_classes, + kernel_size=kernel_size, + padding=1, + ) + self.pad_value = pad_value + + def forward(self, input, batch_posistions=None): + pad_mask = ( + (input == self.pad_value).all(dim=-1).all(dim=-1).all(dim=-1) + ) # BxT pad mask + pad_mask = pad_mask if pad_mask.any() else None + + # FORWARD + _, forward_states = self.convlstm_forward(input, pad_mask=pad_mask) + out = forward_states[0][1] # take last cell state as embedding + + # BACKWARD + x_reverse = torch.flip(input, dims=[1]) + if pad_mask is not None: + pmr = torch.flip(pad_mask.float(), dims=[1]).bool() + x_reverse = torch.masked_fill(x_reverse, pmr[:, :, None, None, None], 0) + # Fill leading padded positions with 0s + _, backward_states = self.convlstm_backward(x_reverse) + + out = torch.cat([out, backward_states[0][1]], dim=1) + out = self.classification_layer(out) + return out + + +class BConvLSTM(nn.Module): + def __init__(self, input_size, input_dim, hidden_dim, kernel_size): + super(BConvLSTM, self).__init__() + self.convlstm_forward = ConvLSTM( + input_dim=input_dim, + input_size=input_size, + hidden_dim=hidden_dim, + kernel_size=kernel_size, + return_all_layers=False, + ) + self.convlstm_backward = ConvLSTM( + input_dim=input_dim, + input_size=input_size, + hidden_dim=hidden_dim, + kernel_size=kernel_size, + return_all_layers=False, + ) + + def forward(self, input, pad_mask=None): + # FORWARD + _, forward_states = self.convlstm_forward(input, pad_mask=pad_mask) + out = forward_states[0][1] # take last cell state as embedding + + # BACKWARD + x_reverse = torch.flip(input, dims=[1]) + if pad_mask is not None: + pmr = torch.flip(pad_mask.float(), dims=[1]).bool() + x_reverse = torch.masked_fill(x_reverse, pmr[:, :, None, None, None], 0) + # Fill leading padded positions with 0s + _, backward_states = self.convlstm_backward(x_reverse) + + out = torch.cat([out, backward_states[0][1]], dim=1) + return out diff --git a/UnCRtainTS/model/src/backbones/diffcr_no_diffusion.py b/UnCRtainTS/model/src/backbones/diffcr_no_diffusion.py new file mode 100644 index 0000000000000000000000000000000000000000..c0e5efc090ef311517b95efb15adc0a86022e874 --- /dev/null +++ b/UnCRtainTS/model/src/backbones/diffcr_no_diffusion.py @@ -0,0 +1,427 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F +from functools import partial +import math +from abc import abstractmethod + + +class EmbedBlock(nn.Module): + """ + Any module where forward() takes embeddings as a second argument. + """ + + @abstractmethod + def forward(self, x, emb): + """ + Apply the module to `x` given `emb` embeddings. + """ + + +class EmbedSequential(nn.Sequential, EmbedBlock): + """ + A sequential module that passes embeddings to the children that + support it as an extra input. + """ + + def forward(self, x, emb): + for layer in self: + if isinstance(layer, EmbedBlock): + x = layer(x, emb) + else: + x = layer(x) + return x + + +def gamma_embedding(gammas, dim, max_period=10000): + """ + Create sinusoidal timestep embeddings. + :param gammas: a 1-D Tensor of N indices, one per batch element. + These may be fractional. + :param dim: the dimension of the output. + :param max_period: controls the minimum frequency of the embeddings. + :return: an [N x dim] Tensor of positional embeddings. + """ + half = dim // 2 + freqs = torch.exp( + -math.log(max_period) * torch.arange(start=0, + end=half, dtype=torch.float32) / half + ).to(device=gammas.device) + args = gammas[:, None].float() * freqs[None] + embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1) + if dim % 2: + embedding = torch.cat( + [embedding, torch.zeros_like(embedding[:, :1])], dim=-1) + return embedding + + +class LayerNormFunction(torch.autograd.Function): + + @staticmethod + def forward(ctx, x, weight, bias, eps): + ctx.eps = eps + N, C, H, W = x.size() + mu = x.mean(1, keepdim=True) + var = (x - mu).pow(2).mean(1, keepdim=True) + y = (x - mu) / (var + eps).sqrt() + ctx.save_for_backward(y, var, weight) + y = weight.view(1, C, 1, 1) * y + bias.view(1, C, 1, 1) + return y + + @staticmethod + def backward(ctx, grad_output): + eps = ctx.eps + + N, C, H, W = grad_output.size() + y, var, weight = ctx.saved_variables + g = grad_output * weight.view(1, C, 1, 1) + mean_g = g.mean(dim=1, keepdim=True) + + mean_gy = (g * y).mean(dim=1, keepdim=True) + gx = 1. / torch.sqrt(var + eps) * (g - y * mean_gy - mean_g) + return gx, (grad_output * y).sum(dim=3).sum(dim=2).sum(dim=0), grad_output.sum(dim=3).sum(dim=2).sum( + dim=0), None + + +class LayerNorm2d(nn.Module): + + def __init__(self, channels, eps=1e-6): + super(LayerNorm2d, self).__init__() + self.register_parameter('weight', nn.Parameter(torch.ones(channels))) + self.register_parameter('bias', nn.Parameter(torch.zeros(channels))) + self.eps = eps + + def forward(self, x): + return LayerNormFunction.apply(x, self.weight, self.bias, self.eps) + + +class SimpleGate(nn.Module): + def forward(self, x): + x1, x2 = x.chunk(2, dim=1) + return x1 * x2 + + +class CondNAFBlock(nn.Module): + def __init__(self, c, DW_Expand=2, FFN_Expand=2, drop_out_rate=0.): + super().__init__() + dw_channel = c * DW_Expand + self.conv1 = nn.Conv2d(in_channels=c, out_channels=dw_channel, + kernel_size=1, padding=0, stride=1, groups=1, bias=True) + self.conv2 = nn.Conv2d(in_channels=dw_channel, out_channels=dw_channel, kernel_size=3, padding=1, stride=1, groups=dw_channel, + bias=True) + self.conv3 = nn.Conv2d(in_channels=dw_channel // 2, out_channels=c, + kernel_size=1, padding=0, stride=1, groups=1, bias=True) + + # Simplified Channel Attention + # self.sca = nn.Sequential( + # nn.AdaptiveAvgPool2d(1), + # nn.Conv2d(in_channels=dw_channel // 2, out_channels=dw_channel // 2, kernel_size=1, padding=0, stride=1, + # groups=1, bias=True), + # ) + self.sca_avg = nn.Sequential( + nn.AdaptiveAvgPool2d(1), + nn.Conv2d(in_channels=dw_channel // 4, out_channels=dw_channel // 4, kernel_size=1, padding=0, stride=1, + groups=1, bias=True), + ) + self.sca_max = nn.Sequential( + nn.AdaptiveMaxPool2d(1), + nn.Conv2d(in_channels=dw_channel // 4, out_channels=dw_channel // 4, kernel_size=1, padding=0, stride=1, + groups=1, bias=True), + ) + + # SimpleGate + self.sg = SimpleGate() + + ffn_channel = FFN_Expand * c + self.conv4 = nn.Conv2d(in_channels=c, out_channels=ffn_channel, + kernel_size=1, padding=0, stride=1, groups=1, bias=True) + self.conv5 = nn.Conv2d(in_channels=ffn_channel // 2, out_channels=c, + kernel_size=1, padding=0, stride=1, groups=1, bias=True) + + self.norm1 = LayerNorm2d(c) + self.norm2 = LayerNorm2d(c) + + self.dropout1 = nn.Dropout( + drop_out_rate) if drop_out_rate > 0. else nn.Identity() + self.dropout2 = nn.Dropout( + drop_out_rate) if drop_out_rate > 0. else nn.Identity() + + self.beta = nn.Parameter(torch.zeros((1, c, 1, 1)), requires_grad=True) + self.gamma = nn.Parameter(torch.zeros( + (1, c, 1, 1)), requires_grad=True) + + def forward(self, inp): + x = inp + + x = self.norm1(x) + + x = self.conv1(x) + x = self.conv2(x) + x = self.sg(x) + x_avg, x_max = x.chunk(2, dim=1) + x_avg = self.sca_avg(x_avg)*x_avg + x_max = self.sca_max(x_max)*x_max + x = torch.cat([x_avg, x_max], dim=1) + x = self.conv3(x) + + x = self.dropout1(x) + + y = inp + x * self.beta + + x = self.conv4(self.norm2(y)) + x = self.sg(x) + x = self.conv5(x) + + x = self.dropout2(x) + + return y + x * self.gamma + + +class NAFBlock(nn.Module): + def __init__(self, c, DW_Expand=2, FFN_Expand=2, drop_out_rate=0.): + super().__init__() + dw_channel = c * DW_Expand + self.conv1 = nn.Conv2d(in_channels=c, out_channels=dw_channel, + kernel_size=1, padding=0, stride=1, groups=1, bias=True) + self.conv2 = nn.Conv2d(in_channels=dw_channel, out_channels=dw_channel, kernel_size=3, padding=1, stride=1, groups=dw_channel, + bias=True) + self.conv3 = nn.Conv2d(in_channels=dw_channel // 2, out_channels=c, + kernel_size=1, padding=0, stride=1, groups=1, bias=True) + + # Simplified Channel Attention + # self.sca = nn.Sequential( + # nn.AdaptiveAvgPool2d(1), + # nn.Conv2d(in_channels=dw_channel // 2, out_channels=dw_channel // 2, kernel_size=1, padding=0, stride=1, + # groups=1, bias=True), + # ) + self.sca_avg = nn.Sequential( + nn.AdaptiveAvgPool2d(1), + nn.Conv2d(in_channels=dw_channel // 4, out_channels=dw_channel // 4, kernel_size=1, padding=0, stride=1, + groups=1, bias=True), + ) + self.sca_max = nn.Sequential( + nn.AdaptiveMaxPool2d(1), + nn.Conv2d(in_channels=dw_channel // 4, out_channels=dw_channel // 4, kernel_size=1, padding=0, stride=1, + groups=1, bias=True), + ) + + # SimpleGate + self.sg = SimpleGate() + + ffn_channel = FFN_Expand * c + self.conv4 = nn.Conv2d(in_channels=c, out_channels=ffn_channel, + kernel_size=1, padding=0, stride=1, groups=1, bias=True) + self.conv5 = nn.Conv2d(in_channels=ffn_channel // 2, out_channels=c, + kernel_size=1, padding=0, stride=1, groups=1, bias=True) + + self.norm1 = LayerNorm2d(c) + self.norm2 = LayerNorm2d(c) + + self.dropout1 = nn.Dropout( + drop_out_rate) if drop_out_rate > 0. else nn.Identity() + self.dropout2 = nn.Dropout( + drop_out_rate) if drop_out_rate > 0. else nn.Identity() + + self.beta = nn.Parameter(torch.zeros((1, c, 1, 1)), requires_grad=True) + self.gamma = nn.Parameter(torch.zeros( + (1, c, 1, 1)), requires_grad=True) + # self.time_emb = nn.Sequential( + # nn.SiLU(), + # nn.Linear(256, c), + # ) + + def forward(self, inp): + x = inp + + x = self.norm1(x) + + x = self.conv1(x) + x = self.conv2(x) + x = self.sg(x) + x_avg, x_max = x.chunk(2, dim=1) + x_avg = self.sca_avg(x_avg)*x_avg + x_max = self.sca_max(x_max)*x_max + x = torch.cat([x_avg, x_max], dim=1) + x = self.conv3(x) + + x = self.dropout1(x) + + y = inp + x * self.beta + + # y = y+self.time_emb(t)[..., None, None] + + x = self.conv4(self.norm2(y)) + x = self.sg(x) + x = self.conv5(x) + + x = self.dropout2(x) + + return y + x * self.gamma + + +class UNCRTAINTS(nn.Module): + + def __init__( + self, + input_dim=15, + out_conv=[13], + width=64, + middle_blk_num=1, + enc_blk_nums=[1, 1, 1, 1], + dec_blk_nums=[1, 1, 1, 1], + encoder_widths=[128], + decoder_widths=[128,128,128,128,128], + out_nonlin_mean=False, + out_nonlin_var='relu', + agg_mode="att_group", + encoder_norm="group", + decoder_norm="batch", + n_head=16, + d_model=256, + d_k=4, + pad_value=0, + padding_mode="reflect", + positional_encoding=True, + covmode='diag', + scale_by=1, + separate_out=False, + use_v=False, + block_type='mbconv', + is_mono=False + ): + super().__init__() + + self.intro = nn.Conv2d(in_channels=input_dim, out_channels=width, kernel_size=3, padding=1, stride=1, groups=1, + bias=True) + # self.cond_intro = nn.Conv2d(in_channels=img_channel+2, out_channels=width, kernel_size=3, padding=1, stride=1, groups=1, + # bias=True) + self.ending = nn.Conv2d(in_channels=width, out_channels=out_conv[0], kernel_size=3, padding=1, stride=1, groups=1, + bias=True) + # self.inp_ending = nn.Conv2d(in_channels=img_channel, out_channels=3, kernel_size=3, padding=1, stride=1, groups=1, + # bias=True) + + self.encoders = nn.ModuleList() + self.cond_encoders = nn.ModuleList() + + self.decoders = nn.ModuleList() + + self.middle_blks = nn.ModuleList() + + self.ups = nn.ModuleList() + + self.downs = nn.ModuleList() + self.cond_downs = nn.ModuleList() + + chan = width + for num in enc_blk_nums: + self.encoders.append( + nn.Sequential( + *[NAFBlock(chan) for _ in range(num)] + ) + ) + self.cond_encoders.append( + nn.Sequential( + *[CondNAFBlock(chan) for _ in range(num)] + ) + ) + self.downs.append( + nn.Conv2d(chan, 2*chan, 2, 2) + ) + self.cond_downs.append( + nn.Conv2d(chan, 2*chan, 2, 2) + ) + chan = chan * 2 + + self.middle_blks = \ + nn.Sequential( + *[NAFBlock(chan) for _ in range(middle_blk_num)] + ) + + for num in dec_blk_nums: + self.ups.append( + nn.Sequential( + nn.Conv2d(chan, chan * 2, 1, bias=False), + nn.PixelShuffle(2) + ) + ) + chan = chan // 2 + self.decoders.append( + nn.Sequential( + *[NAFBlock(chan) for _ in range(num)] + ) + ) + + self.padder_size = 2 ** len(self.encoders) + self.map = nn.Sequential( + nn.Linear(64, 256), + nn.SiLU(), + nn.Linear(256, 256), + ) + + def forward(self, inp): + inp = self.check_image_size(inp) + x = self.intro(inp) + + encs = [] + + for encoder, down in zip(self.encoders, self.downs): + x = encoder(x) + # b, c, h, w = cond.shape + # tmp_cond = cond.view(b//3, 3, c, h, w).sum(dim=1) + # tmp_cond = cond + # x = x + tmp_cond + encs.append(x) + x = down(x) + # cond = cond_down(cond) + + x = self.middle_blks(x) + + for decoder, up, enc_skip in zip(self.decoders, self.ups, encs[::-1]): + x = up(x) + x = x + enc_skip + x = decoder(x) + + x = self.ending(x) + # x = x + self.inp_ending(inp) + + return x + + def check_image_size(self, x): + _, _, h, w = x.size() + mod_pad_h = (self.padder_size - h % + self.padder_size) % self.padder_size + mod_pad_w = (self.padder_size - w % + self.padder_size) % self.padder_size + x = F.pad(x, (0, mod_pad_w, 0, mod_pad_h)) + return x + + +if __name__ == '__main__': + # unit test for ground resolution + inp = torch.randn(1, 15, 256, 256) + net = UNCRTAINTS( + input_dim=15, + out_conv=[13], + width=64, + middle_blk_num=1, + enc_blk_nums=[1, 1, 1, 1], + dec_blk_nums=[1, 1, 1, 1], + ) + out = net(inp) + assert out.shape == (1, 13, 256, 256) + + # from thop import profile + # out_shape = (1, 12, 384, 384) + # input_shape = (1, 13, 384, 384) + # model = DiffCR( + # img_channel=13, + # width=32, + # middle_blk_num=1, + # enc_blk_nums=[1, 1, 1, 1], + # dec_blk_nums=[1, 1, 1, 1], + # ) + # # 使用 thop 的 profile 函数来获取 FLOPs 和参数量 + # flops, params = profile(model, inputs=(torch.randn(out_shape), torch.ones(1,), torch.randn(input_shape))) + # print(f"FLOPs: {flops / 1e9} G") + # print(f"Parameters: {params / 1e6} M") + diff --git a/UnCRtainTS/model/src/backbones/fpn.py b/UnCRtainTS/model/src/backbones/fpn.py new file mode 100644 index 0000000000000000000000000000000000000000..295a09ead5c6ee15176776d16285af9e89508535 --- /dev/null +++ b/UnCRtainTS/model/src/backbones/fpn.py @@ -0,0 +1,216 @@ +import torch.nn as nn +import torch + +from src.backbones.convlstm import ConvLSTM + + +class FPNConvLSTM(nn.Module): + def __init__( + self, + input_dim, + num_classes, + inconv=[32, 64], + n_levels=5, + n_channels=64, + hidden_size=88, + input_shape=(128, 128), + mid_conv=True, + pad_value=0, + ): + """ + Feature Pyramid Network with ConvLSTM baseline. + Args: + input_dim (int): Number of channels in the input images. + num_classes (int): Number of classes. + inconv (List[int]): Widths of the input convolutional layers. + n_levels (int): Number of different levels in the feature pyramid. + n_channels (int): Number of channels for each channel of the pyramid. + hidden_size (int): Hidden size of the ConvLSTM. + input_shape (int,int): Shape (H,W) of the input images. + mid_conv (bool): If True, the feature pyramid is fed to a convolutional layer + to reduce dimensionality before being given to the ConvLSTM. + pad_value (float): Padding value (temporal) used by the dataloader. + """ + super(FPNConvLSTM, self).__init__() + self.pad_value = pad_value + self.inconv = ConvBlock( + nkernels=[input_dim] + inconv, norm="group", pad_value=pad_value + ) + self.pyramid = PyramidBlock( + input_dim=inconv[-1], + n_channels=n_channels, + n_levels=n_levels, + pad_value=pad_value, + ) + + if mid_conv: + dim = n_channels * n_levels // 2 + self.mid_conv = ConvBlock( + nkernels=[self.pyramid.out_channels, dim], + pad_value=pad_value, + norm="group", + ) + else: + dim = self.pyramid.out_channels + self.mid_conv = None + + self.convlstm = ConvLSTM( + input_dim=dim, + input_size=input_shape, + hidden_dim=hidden_size, + kernel_size=(3, 3), + return_all_layers=False, + ) + + self.outconv = nn.Conv2d( + in_channels=hidden_size, out_channels=num_classes, kernel_size=1 + ) + + def forward(self, input, batch_positions=None): + pad_mask = ( + (input == self.pad_value).all(dim=-1).all(dim=-1).all(dim=-1) + ) # BxT pad mask + pad_mask = pad_mask if pad_mask.any() else None + + out = self.inconv.smart_forward(input) + out = self.pyramid.smart_forward(out) + if self.mid_conv is not None: + out = self.mid_conv.smart_forward(out) + _, out = self.convlstm(out, pad_mask=pad_mask) + out = out[0][1] + out = self.outconv(out) + + return out + + +class TemporallySharedBlock(nn.Module): + def __init__(self, pad_value=None): + super(TemporallySharedBlock, self).__init__() + self.out_shape = None + self.pad_value = pad_value + + def smart_forward(self, input): + if len(input.shape) == 4: + return self.forward(input) + else: + b, t, c, h, w = input.shape + + if self.pad_value is not None: + dummy = torch.zeros(input.shape, device=input.device).float() + self.out_shape = self.forward(dummy.view(b * t, c, h, w)).shape + + out = input.view(b * t, c, h, w) + if self.pad_value is not None: + pad_mask = (out == self.pad_value).all(dim=-1).all(dim=-1).all(dim=-1) + if pad_mask.any(): + temp = ( + torch.ones( + self.out_shape, device=input.device, requires_grad=False + ) + * self.pad_value + ) + temp[~pad_mask] = self.forward(out[~pad_mask]) + out = temp + else: + out = self.forward(out) + else: + out = self.forward(out) + _, c, h, w = out.shape + out = out.view(b, t, c, h, w) + return out + + +class PyramidBlock(TemporallySharedBlock): + def __init__(self, input_dim, n_levels=5, n_channels=64, pad_value=None): + """ + Feature Pyramid Block. Performs atrous convolutions with different strides + and concatenates the resulting feature maps along the channel dimension. + Args: + input_dim (int): Number of channels in the input images. + n_levels (int): Number of levels. + n_channels (int): Number of channels per level. + pad_value (float): Padding value (temporal) used by the dataloader. + """ + super(PyramidBlock, self).__init__(pad_value=pad_value) + + dilations = [2 ** i for i in range(n_levels - 1)] + self.inconv = nn.Conv2d(input_dim, n_channels, kernel_size=3, padding=1) + self.convs = nn.ModuleList( + [ + nn.Conv2d( + in_channels=n_channels, + out_channels=n_channels, + kernel_size=3, + stride=1, + dilation=d, + padding=d, + padding_mode="reflect", + ) + for d in dilations + ] + ) + + self.out_channels = n_levels * n_channels + + def forward(self, input): + out = self.inconv(input) + global_avg_pool = out.view(*out.shape[:2], -1).max(dim=-1)[0] + + out = torch.cat([cv(out) for cv in self.convs], dim=1) + + h, w = out.shape[-2:] + out = torch.cat( + [ + out, + global_avg_pool.unsqueeze(-1) + .repeat(1, 1, h) + .unsqueeze(-1) + .repeat(1, 1, 1, w), + ], + dim=1, + ) + + return out + + +class ConvLayer(nn.Module): + def __init__(self, nkernels, norm="batch", k=3, s=1, p=1, n_groups=4): + super(ConvLayer, self).__init__() + layers = [] + if norm == "batch": + nl = nn.BatchNorm2d + elif norm == "instance": + nl = nn.InstanceNorm2d + elif norm == "group": + nl = lambda num_feats: nn.GroupNorm( + num_channels=num_feats, num_groups=n_groups + ) + else: + nl = None + for i in range(len(nkernels) - 1): + layers.append( + nn.Conv2d( + in_channels=nkernels[i], + out_channels=nkernels[i + 1], + kernel_size=k, + padding=p, + stride=s, + padding_mode="reflect", + ) + ) + if nl is not None: + layers.append(nl(nkernels[i + 1])) + layers.append(nn.ReLU()) + self.conv = nn.Sequential(*layers) + + def forward(self, input): + return self.conv(input) + + +class ConvBlock(TemporallySharedBlock): + def __init__(self, nkernels, pad_value=None, norm="batch"): + super(ConvBlock, self).__init__(pad_value=pad_value) + self.conv = ConvLayer(nkernels=nkernels, norm=norm) + + def forward(self, input): + return self.conv(input) diff --git a/UnCRtainTS/model/src/backbones/ltae.py b/UnCRtainTS/model/src/backbones/ltae.py new file mode 100644 index 0000000000000000000000000000000000000000..4f28a7ec3a365f80c95a022612e002ef16ad635a --- /dev/null +++ b/UnCRtainTS/model/src/backbones/ltae.py @@ -0,0 +1,458 @@ +import copy + +import numpy as np +import torch +import torch.nn as nn + +from src.backbones.positional_encoding import PositionalEncoder + + +class LTAE2d(nn.Module): + def __init__( + self, + in_channels=128, + n_head=16, + d_k=4, + mlp=[256, 128], + dropout=0.2, + d_model=256, + T=1000, + return_att=False, + positional_encoding=True, + use_dropout=True + ): + """ + Lightweight Temporal Attention Encoder (L-TAE) for image time series. + Attention-based sequence encoding that maps a sequence of images to a single feature map. + A shared L-TAE is applied to all pixel positions of the image sequence. + Args: + in_channels (int): Number of channels of the input embeddings. + n_head (int): Number of attention heads. + d_k (int): Dimension of the key and query vectors. + mlp (List[int]): Widths of the layers of the MLP that processes the concatenated outputs of the attention heads. + dropout (float): dropout on the MLP-processed values + d_model (int, optional): If specified, the input tensors will first processed by a fully connected layer + to project them into a feature space of dimension d_model. + T (int): Period to use for the positional encoding. + return_att (bool): If true, the module returns the attention masks along with the embeddings (default False) + positional_encoding (bool): If False, no positional encoding is used (default True). + use_dropout (bool): dropout on the attention masks. + """ + super(LTAE2d, self).__init__() + self.in_channels = in_channels + self.mlp = copy.deepcopy(mlp) + self.return_att = return_att + self.n_head = n_head + + if d_model is not None: + self.d_model = d_model + self.inconv = nn.Conv1d(in_channels, d_model, 1) + else: + self.d_model = in_channels + self.inconv = None + assert self.mlp[0] == self.d_model + + if positional_encoding: + self.positional_encoder = PositionalEncoder( + self.d_model // n_head, T=T, repeat=n_head + ) + else: + self.positional_encoder = None + + self.attention_heads = MultiHeadAttention( + n_head=n_head, d_k=d_k, d_in=self.d_model, use_dropout=use_dropout + ) + self.in_norm = nn.GroupNorm( + num_groups=n_head, + num_channels=self.in_channels, + ) + self.out_norm = nn.GroupNorm( + num_groups=n_head, + num_channels=mlp[-1], + ) + + layers = [] + for i in range(len(self.mlp) - 1): + layers.extend( + [ + nn.Linear(self.mlp[i], self.mlp[i + 1]), + nn.BatchNorm1d(self.mlp[i + 1]), + nn.ReLU(), + ] + ) + + self.mlp = nn.Sequential(*layers) + self.dropout = nn.Dropout(dropout) + + def forward(self, x, batch_positions=None, pad_mask=None, return_comp=False): + sz_b, seq_len, d, h, w = x.shape + if pad_mask is not None: + pad_mask = ( + pad_mask.unsqueeze(-1) + .repeat((1, 1, h)) + .unsqueeze(-1) + .repeat((1, 1, 1, w)) + ) # BxTxHxW + pad_mask = ( + pad_mask.permute(0, 2, 3, 1).contiguous().view(sz_b * h * w, seq_len) + ) + + out = x.permute(0, 3, 4, 1, 2).contiguous().view(sz_b * h * w, seq_len, d) + out = self.in_norm(out.permute(0, 2, 1)).permute(0, 2, 1) + + if self.inconv is not None: + out = self.inconv(out.permute(0, 2, 1)).permute(0, 2, 1) + + if self.positional_encoder is not None: + bp = ( + batch_positions.unsqueeze(-1) + .repeat((1, 1, h)) + .unsqueeze(-1) + .repeat((1, 1, 1, w)) + ) # BxTxHxW + bp = bp.permute(0, 2, 3, 1).contiguous().view(sz_b * h * w, seq_len) + out = out + self.positional_encoder(bp) + + # re-shaped attn to [h x B*H*W x T], e.g. torch.Size([16, 2048, 4]) + # in utae.py this is torch.Size([h, B, T, 32, 32]) + # re-shaped output to [h x B*H*W x d_in/h], e.g. torch.Size([16, 2048, 16]) + # in utae.py this is torch.Size([B, 128, 32, 32]) + out, attn = self.attention_heads(out, pad_mask=pad_mask) + + out = ( + out.permute(1, 0, 2).contiguous().view(sz_b * h * w, -1) + ) # Concatenate heads, out is now [B*H*W x d_in/h * h], e.g. [2048 x 256] + + # out is of shape [head x b x t x h x w] + out = self.dropout(self.mlp(out)) + # after MLP, out is of shape [B*H*W x outputLayerOfMLP], e.g. [2048 x 128] + out = self.out_norm(out) if self.out_norm is not None else out + out = out.view(sz_b, h, w, -1).permute(0, 3, 1, 2) + + attn = attn.view(self.n_head, sz_b, h, w, seq_len).permute( + 0, 1, 4, 2, 3 + ) + + # out is of shape [B x outputLayerOfMLP x h x w], e.g. [2, 128, 32, 32] + # attn is of shape [h x B x T x H x W], e.g. [16, 2, 4, 32, 32] + if self.return_att: + return out, attn + else: + return out + + + +class LTAE2dtiny(nn.Module): + def __init__( + self, + in_channels=128, + n_head=16, + d_k=4, + d_model=256, + T=1000, + positional_encoding=True, + ): + """ + Lightweight Temporal Attention Encoder (L-TAE) for image time series. + Attention-based sequence encoding that maps a sequence of images to a single feature map. + A shared L-TAE is applied to all pixel positions of the image sequence. + This is the tiny version, which stops further processing attention-weighted values v + (no longer using an MLP) and only returns the attention matrix attn itself + Args: + in_channels (int): Number of channels of the input embeddings. + n_head (int): Number of attention heads. + d_k (int): Dimension of the key and query vectors. + d_model (int, optional): If specified, the input tensors will first processed by a fully connected layer + to project them into a feature space of dimension d_model. + T (int): Period to use for the positional encoding. + positional_encoding (bool): If False, no positional encoding is used (default True). + """ + super(LTAE2dtiny, self).__init__() + self.in_channels = in_channels + self.n_head = n_head + + if d_model is not None: + self.d_model = d_model + self.inconv = nn.Conv1d(in_channels, d_model, 1) + else: + self.d_model = in_channels + self.inconv = None + + if positional_encoding: + self.positional_encoder = PositionalEncoder( + self.d_model // n_head, T=T, repeat=n_head + ) + else: + self.positional_encoder = None + + self.attention_heads = MultiHeadAttentionSmall( + n_head=n_head, d_k=d_k, d_in=self.d_model + ) + self.in_norm = nn.GroupNorm( + num_groups=n_head, + num_channels=self.in_channels, + ) + + + def forward(self, x, batch_positions=None, pad_mask=None): + sz_b, seq_len, d, h, w = x.shape + if pad_mask is not None: + pad_mask = ( + pad_mask.unsqueeze(-1) + .repeat((1, 1, h)) + .unsqueeze(-1) + .repeat((1, 1, 1, w)) + ) # BxTxHxW + pad_mask = ( + pad_mask.permute(0, 2, 3, 1).contiguous().view(sz_b * h * w, seq_len) + ) + + out = x.permute(0, 3, 4, 1, 2).contiguous().view(sz_b * h * w, seq_len, d) + out = self.in_norm(out.permute(0, 2, 1)).permute(0, 2, 1) + + if self.inconv is not None: + out = self.inconv(out.permute(0, 2, 1)).permute(0, 2, 1) + + if self.positional_encoder is not None: + bp = ( + batch_positions.unsqueeze(-1) + .repeat((1, 1, h)) + .unsqueeze(-1) + .repeat((1, 1, 1, w)) + ) # BxTxHxW + bp = bp.permute(0, 2, 3, 1).contiguous().view(sz_b * h * w, seq_len) + out = out + self.positional_encoder(bp) + + # re-shaped attn to [h x B*H*W x T], e.g. torch.Size([16, 2048, 4]) + # in utae.py this is torch.Size([h, B, T, 32, 32]) + # re-shaped output to [h x B*H*W x d_in/h], e.g. torch.Size([16, 2048, 16]) + # in utae.py this is torch.Size([B, 128, 32, 32]) + attn = self.attention_heads(out, pad_mask=pad_mask) + + + attn = attn.view(self.n_head, sz_b, h, w, seq_len).permute( + 0, 1, 4, 2, 3 + ) + + # out is of shape [B x outputLayerOfMLP x h x w], e.g. [2, 128, 32, 32] + # attn is of shape [h x B x T x H x W], e.g. [16, 2, 4, 32, 32] + return attn + + +# this class still uses ScaledDotProductAttention (including dropout) +# and always computes and returns att*v +class MultiHeadAttention(nn.Module): + """Multi-Head Attention module + Modified from github.com/jadore801120/attention-is-all-you-need-pytorch + """ + + def __init__(self, n_head, d_k, d_in, use_dropout=True): + super().__init__() + self.n_head = n_head + self.d_k = d_k + self.d_in = d_in # e.g. self.d_model in LTAE2d + + # define H x k queries, they are input-independent in LTAE + self.Q = nn.Parameter(torch.zeros((n_head, d_k))).requires_grad_(True) + nn.init.normal_(self.Q, mean=0, std=np.sqrt(2.0 / (d_k))) + + self.fc1_k = nn.Linear(d_in, n_head * d_k) + nn.init.normal_(self.fc1_k.weight, mean=0, std=np.sqrt(2.0 / (d_k))) + + attn_dropout=0.1 if use_dropout else 0.0 + self.attention = ScaledDotProductAttention(temperature=np.power(d_k, 0.5), attn_dropout=attn_dropout) + + def forward(self, v, pad_mask=None, return_comp=False): + d_k, d_in, n_head = self.d_k, self.d_in, self.n_head + # values v are of shapes [B*H*W, T, self.d_in=self.d_model], e.g. [2*32*32=2048 x 4 x 256] (see: sz_b * h * w, seq_len, d) + # where self.d_in=self.d_model is the output dimension of the FC-projected features + sz_b, seq_len, _ = v.size() + + q = torch.stack([self.Q for _ in range(sz_b)], dim=1).view(-1, d_k) # (n*b) x d_k + + k = self.fc1_k(v).view(sz_b, seq_len, n_head, d_k) + k = k.permute(2, 0, 1, 3).contiguous().view(-1, seq_len, d_k) # (n*b) x lk x dk + + if pad_mask is not None: + pad_mask = pad_mask.repeat( + (n_head, 1) + ) # replicate pad_mask for each head (nxb) x lk + + # attn is of shape [B*H*W*h, 1, T], e.g. [2*32*32*16=32768 x 1 x 4], e.g. Size([32768, 1, 4]) + # v is of shape [B*H*W*h, T, self.d_in/h], e.g. [2*32*32*16=32768 x 4 x 256/16=16], e.g. Size([32768, 4, 16]) + # output is of shape [B*H*W*h, 1, h], e.g. [2*32*32*16=32768 x 1 x 16], e.g. Size([32768, 1, 16]) + v = torch.stack(v.split(v.shape[-1] // n_head, dim=-1)).view(n_head * sz_b, seq_len, -1) + if return_comp: + output, attn, comp = self.attention( + q, k, v, pad_mask=pad_mask, return_comp=return_comp + ) + else: + output, attn = self.attention( + q, k, v, pad_mask=pad_mask, return_comp=return_comp + ) + + attn = attn.view(n_head, sz_b, 1, seq_len) + attn = attn.squeeze(dim=2) + + output = output.view(n_head, sz_b, 1, d_in // n_head) + output = output.squeeze(dim=2) + + # re-shaped attn to [h x B*H*W x T], e.g. torch.Size([16, 2048, 4]) + # in utae.py this is torch.Size([h, B, T, 32, 32]) + # re-shaped output to [h x B*H*W x d_in/h], e.g. torch.Size([16, 2048, 16]) + # in utae.py this is torch.Size([B, 128, 32, 32]) + if return_comp: + return output, attn, comp + else: + return output, attn + + +# this class uses ScaledDotProductAttentionSmall (excluding dropout) +# and only optionally computes and returns att*v +class MultiHeadAttentionSmall(nn.Module): + """Multi-Head Attention module + Modified from github.com/jadore801120/attention-is-all-you-need-pytorch + """ + + def __init__(self, n_head, d_k, d_in): + super().__init__() + self.n_head = n_head # e.g. 16 + self.d_k = d_k # e.g. 4, number of keys per head + self.d_in = d_in # e.g. 256, self.d_model in LTAE2d + + # define H x k queries, they are input-independent in LTAE + self.Q = nn.Parameter(torch.zeros((n_head, d_k))).requires_grad_(True) + nn.init.normal_(self.Q, mean=0, std=np.sqrt(2.0 / (d_k))) + + self.fc1_k = nn.Linear(d_in, n_head * d_k) + """ + # consider using deeper mappings with nonlinearities, + # but this is somewhat against the original Transformer spirit + self.fc1_k = nn.Linear(d_in, d_in) + self.bn2_k = nn.BatchNorm1d(d_in) + self.fc2_k = nn.Linear(d_in, n_head * d_k) + self.bn2_k = nn.BatchNorm1d(n_head * d_k) + """ + + nn.init.normal_(self.fc1_k.weight, mean=0, std=np.sqrt(2.0 / (d_k))) + #nn.init.normal_(self.fc2_k.weight, mean=0, std=np.sqrt(2.0 / (d_k))) + self.attention = ScaledDotProductAttentionSmall(temperature=np.power(d_k, 0.5)) + + def forward(self, v, pad_mask=None, return_comp=False, weight_v=False): + d_k, d_in, n_head = self.d_k, self.d_in, self.n_head + # values v are of shapes [B*H*W, T, self.d_in=self.d_model], e.g. [2*32*32=2048 x 4 x 256] (see: sz_b * h * w, seq_len, d) + # where self.d_in=self.d_model is the output dimension of the FC-projected features + sz_b, seq_len, _ = v.size() + + q = torch.stack([self.Q for _ in range(sz_b)], dim=1).view(-1, d_k) # (n*b) x d_k + + k = self.fc1_k(v).view(sz_b, seq_len, n_head, d_k) + k = k.permute(2, 0, 1, 3).contiguous().view(-1, seq_len, d_k) # (n*b) x lk x dk + + if pad_mask is not None: + pad_mask = pad_mask.repeat( + (n_head, 1) + ) # replicate pad_mask for each head (nxb) x lk + + # attn is of shape [B*H*W*h, 1, T], e.g. [2*32*32*16=32768 x 1 x 4], e.g. Size([32768, 1, 4]) + # v is of shape [B*H*W*h, T, self.d_in/h], e.g. [2*32*32*16=32768 x 4 x 256/16=16], e.g. Size([32768, 4, 16]) + # output is of shape [B*H*W*h, 1, h], e.g. [2*32*32*16=32768 x 1 x 16], e.g. Size([32768, 1, 16]) + v = torch.stack(v.split(v.shape[-1] // n_head, dim=-1)).view(n_head * sz_b, seq_len, -1) + if weight_v: + output, attn = self.attention(q, k, v, pad_mask=pad_mask, return_comp=return_comp, weight_v=weight_v) + if return_comp: + output, attn, comp = self.attention(q, k, v, pad_mask=pad_mask, return_comp=return_comp, weight_v=weight_v) + else: + attn = self.attention(q, k, v, pad_mask=pad_mask, return_comp=return_comp, weight_v=weight_v) + + attn = attn.view(n_head, sz_b, 1, seq_len) + attn = attn.squeeze(dim=2) + + if weight_v: + output = output.view(n_head, sz_b, 1, d_in // n_head) + output = output.squeeze(dim=2) + + # re-shaped attn to [h x B*H*W x T], e.g. torch.Size([16, 2048, 4]) + # in utae.py this is torch.Size([h, B, T, 32, 32]) + # re-shaped output to [h x B*H*W x d_in/h], e.g. torch.Size([16, 2048, 16]) + # in utae.py this is torch.Size([B, 128, 32, 32]) + + if return_comp: + return output, attn, comp + else: + return output, attn + + return attn + + +class ScaledDotProductAttention(nn.Module): + """Scaled Dot-Product Attention + Modified from github.com/jadore801120/attention-is-all-you-need-pytorch + """ + + def __init__(self, temperature, attn_dropout=0.1): + super().__init__() + self.temperature = temperature + self.dropout = nn.Dropout(attn_dropout) + self.softmax = nn.Softmax(dim=2) + + def forward(self, q, k, v, pad_mask=None, return_comp=False): + attn = torch.matmul(q.unsqueeze(1), k.transpose(1, 2)) + attn = attn / self.temperature + if pad_mask is not None: + attn = attn.masked_fill(pad_mask.unsqueeze(1), -1e3) + if return_comp: + comp = attn + # attn is of shape [B*H*W*h, 1, T], e.g. [2*32*32*16=32768 x 1 x 4] + # v is of shape [B*H*W*h, T, self.d_in/h], e.g. [2*32*32*16=32768 x 4 x 256/16=16] + # output is of shape [B*H*W*h, 1, h], e.g. [2*32*32*16=32768 x 1 x 16], e.g. Size([32768, 1, 16]) + attn = self.softmax(attn) + attn = self.dropout(attn) + output = torch.matmul(attn, v) + + if return_comp: + return output, attn, comp + else: + return output, attn + +# no longer using dropout (before upsampling) +# but optionally doing attn*v weighting +class ScaledDotProductAttentionSmall(nn.Module): + """Scaled Dot-Product Attention + Modified from github.com/jadore801120/attention-is-all-you-need-pytorch + """ + + def __init__(self, temperature): + super().__init__() + self.temperature = temperature + #self.dropout = nn.Dropout(attn_dropout) # moved dropout after bilinear interpolation + self.softmax = nn.Softmax(dim=2) + + def forward(self, q, k, v, pad_mask=None, return_comp=False, weight_v=False): + attn = torch.matmul(q.unsqueeze(1), k.transpose(1, 2)) + attn = attn / self.temperature + if pad_mask is not None: + attn = attn.masked_fill(pad_mask.unsqueeze(1), -1e3) + if return_comp: + comp = attn + # attn is of shape [B*H*W*h, 1, T], e.g. [2*32*32*16=32768 x 1 x 4] + # v is of shape [B*H*W*h, T, self.d_in/h], e.g. [2*32*32*16=32768 x 4 x 256/16=16] + # output is of shape [B*H*W*h, 1, h], e.g. [2*32*32*16=32768 x 1 x 16], e.g. Size([32768, 1, 16]) + attn = self.softmax(attn) + + """ + # no longer using dropout on attention matrices before the upsampling + # this is now done after bilinear interpolation only + + attn = self.dropout(attn) + """ + + if weight_v: + # optionally using the weighted values + output = torch.matmul(attn, v) + + if return_comp: + return output, attn, comp + else: + return output, attn + return attn \ No newline at end of file diff --git a/UnCRtainTS/model/src/backbones/positional_encoding.py b/UnCRtainTS/model/src/backbones/positional_encoding.py new file mode 100644 index 0000000000000000000000000000000000000000..7e85c2a2e8867f9d25dcd065413cdcc3d3455b28 --- /dev/null +++ b/UnCRtainTS/model/src/backbones/positional_encoding.py @@ -0,0 +1,31 @@ +import torch +import torch.nn as nn + + +class PositionalEncoder(nn.Module): + def __init__(self, d, T=1000, repeat=None, offset=0): + super(PositionalEncoder, self).__init__() + self.d = d + self.T = T + self.repeat = repeat + self.denom = torch.pow( + T, 2 * (torch.arange(offset, offset + d).float() // 2) / d + ) + self.updated_location = False + + def forward(self, batch_positions): + if not self.updated_location: + self.denom = self.denom.to(batch_positions.device) + self.updated_location = True + sinusoid_table = ( + batch_positions[:, :, None] / self.denom[None, None, :] + ) # B x T x C + sinusoid_table[:, :, 0::2] = torch.sin(sinusoid_table[:, :, 0::2]) # dim 2i + sinusoid_table[:, :, 1::2] = torch.cos(sinusoid_table[:, :, 1::2]) # dim 2i+1 + + if self.repeat is not None: + sinusoid_table = torch.cat( + [sinusoid_table for _ in range(self.repeat)], dim=-1 + ) + + return sinusoid_table diff --git a/UnCRtainTS/model/src/backbones/uncrtaints.py b/UnCRtainTS/model/src/backbones/uncrtaints.py new file mode 100644 index 0000000000000000000000000000000000000000..f98bfa877bf88e8f0f76fdc464b41b19b23b6ad8 --- /dev/null +++ b/UnCRtainTS/model/src/backbones/uncrtaints.py @@ -0,0 +1,888 @@ +""" +UnCRtainTS Implementation +Author: Patrick Ebel (github/patrickTUM) +License: MIT +""" + +import torch +import torch.nn as nn +import sys +sys.path.append("./model") +from src.backbones.utae import ConvLayer, ConvBlock, TemporallySharedBlock +from src.backbones.ltae import LTAE2d, LTAE2dtiny + +S2_BANDS = 13 + + +def get_norm_layer(out_channels, num_feats, n_groups=4, layer_type='batch'): + if layer_type == 'batch': + return nn.BatchNorm2d(out_channels) + elif layer_type == 'instance': + return nn.InstanceNorm2d(out_channels) + elif layer_type == 'group': + return nn.GroupNorm(num_channels=num_feats, num_groups=n_groups) + +class ResidualConvBlock(TemporallySharedBlock): + def __init__( + self, + nkernels, + pad_value=None, + norm="batch", + n_groups=4, + #last_relu=True, + k=3, s=1, p=1, + padding_mode="reflect", + ): + super(ResidualConvBlock, self).__init__(pad_value=pad_value) + + self.conv1 = ConvLayer( + nkernels=nkernels, + norm=norm, + last_relu=True, + k=k, s=s, p=p, + n_groups=n_groups, + padding_mode=padding_mode, + ) + self.conv2 = ConvLayer( + nkernels=nkernels, + norm=norm, + last_relu=True, + k=k, s=s, p=p, + n_groups=n_groups, + padding_mode=padding_mode, + ) + self.conv3 = ConvLayer( + nkernels=nkernels, + #norm='none', + #last_relu=False, + norm=norm, + last_relu=True, + k=k, s=s, p=p, + n_groups=n_groups, + padding_mode=padding_mode, + ) + + def forward(self, input): + + out1 = self.conv1(input) # followed by built-in ReLU & norm + out2 = self.conv2(out1) # followed by built-in ReLU & norm + out3 = input + self.conv3(out2) # omit norm & ReLU + return out3 + + +class PreNorm(nn.Module): + def __init__(self, dim, fn, norm, n_groups=4): + super().__init__() + self.norm = get_norm_layer(dim, dim, n_groups, norm) + self.fn = fn + + def forward(self, x, **kwargs): + return self.fn(self.norm(x), **kwargs) + + +class SE(nn.Module): + def __init__(self, inp, oup, expansion=0.25): + super().__init__() + self.avg_pool = nn.AdaptiveAvgPool2d(1) + self.fc = nn.Sequential( + nn.Linear(oup, int(inp * expansion), bias=False), + nn.GELU(), + nn.Linear(int(inp * expansion), oup, bias=False), + nn.Sigmoid() + ) + + def forward(self, x): + b, c, _, _ = x.size() + y = self.avg_pool(x).view(b, c) + y = self.fc(y).view(b, c, 1, 1) + return x * y + + +class MBConv(TemporallySharedBlock): + def __init__(self, inp, oup, downsample=False, expansion=4, norm='batch', n_groups=4): + super().__init__() + self.downsample = downsample + stride = 1 if self.downsample == False else 2 + hidden_dim = int(inp * expansion) + + if self.downsample: + self.pool = nn.MaxPool2d(3, 2, 1) + self.proj = nn.Conv2d(inp, oup, 1, stride=1, padding=0, bias=False) + + if expansion == 1: + self.conv = nn.Sequential( + # dw + nn.Conv2d(hidden_dim, hidden_dim, 3, stride=stride, + padding=1, padding_mode='reflect', groups=hidden_dim, bias=False), + get_norm_layer(hidden_dim, hidden_dim, n_groups, norm), + nn.GELU(), + # pw-linear + nn.Conv2d(hidden_dim, oup, 1, stride=1, padding=0, bias=False), + get_norm_layer(oup, oup, n_groups, norm), + ) + else: + self.conv = nn.Sequential( + # pw + # down-sample in the first conv + nn.Conv2d(inp, hidden_dim, 1, stride=stride, padding=0, bias=False), + get_norm_layer(hidden_dim, hidden_dim, n_groups, norm), + nn.GELU(), + # dw + nn.Conv2d(hidden_dim, hidden_dim, 3, stride=1, padding=1, padding_mode='reflect', + groups=hidden_dim, bias=False), + get_norm_layer(hidden_dim, hidden_dim, n_groups, norm), + nn.GELU(), + SE(inp, hidden_dim), + # pw-linear + nn.Conv2d(hidden_dim, oup, 1, stride=1, padding=0, bias=False), + get_norm_layer(oup, oup, n_groups, norm), + ) + + self.conv = PreNorm(inp, self.conv, norm, n_groups=4) + + def forward(self, x): + if self.downsample: + return self.proj(self.pool(x)) + self.conv(x) + else: + return x + self.conv(x) + + +class Compact_Temporal_Aggregator(nn.Module): + def __init__(self, mode="mean"): + super(Compact_Temporal_Aggregator, self).__init__() + self.mode = mode + # moved dropout from ScaledDotProductAttention to here, applied after upsampling + self.attn_dropout = nn.Dropout(0.1) # no dropout via: nn.Dropout(0.0) + + def forward(self, x, pad_mask=None, attn_mask=None): + if pad_mask is not None and pad_mask.any(): + if self.mode == "att_group": + n_heads, b, t, h, w = attn_mask.shape + attn = attn_mask.view(n_heads * b, t, h, w) + + if x.shape[-2] > w: + attn = nn.Upsample( + size=x.shape[-2:], mode="bilinear", align_corners=False + )(attn) + # this got moved out of ScaledDotProductAttention, apply after upsampling + attn = self.attn_dropout(attn) + else: + attn = nn.AvgPool2d(kernel_size=w // x.shape[-2])(attn) + + attn = attn.view(n_heads, b, t, *x.shape[-2:]) + attn = attn * (~pad_mask).float()[None, :, :, None, None] + + out = torch.stack(x.chunk(n_heads, dim=2)) # hxBxTxC/hxHxW + out = attn[:, :, :, None, :, :] * out + out = out.sum(dim=2) # sum on temporal dim -> hxBxC/hxHxW + out = torch.cat([group for group in out], dim=1) # -> BxCxHxW + return out + elif self.mode == "att_mean": + attn = attn_mask.mean(dim=0) # average over heads -> BxTxHxW + attn = nn.Upsample( + size=x.shape[-2:], mode="bilinear", align_corners=False + )(attn) + # this got moved out of ScaledDotProductAttention, apply after upsampling + attn = self.attn_dropout(attn) + attn = attn * (~pad_mask).float()[:, :, None, None] + out = (x * attn[:, :, None, :, :]).sum(dim=1) + return out + elif self.mode == "mean": + out = x * (~pad_mask).float()[:, :, None, None, None] + out = out.sum(dim=1) / (~pad_mask).sum(dim=1)[:, None, None, None] + return out + else: + if self.mode == "att_group": + n_heads, b, t, h, w = attn_mask.shape + attn = attn_mask.view(n_heads * b, t, h, w) + if x.shape[-2] > w: + attn = nn.Upsample( + size=x.shape[-2:], mode="bilinear", align_corners=False + )(attn) + # this got moved out of ScaledDotProductAttention, apply after upsampling + attn = self.attn_dropout(attn) + else: + attn = nn.AvgPool2d(kernel_size=w // x.shape[-2])(attn) + attn = attn.view(n_heads, b, t, *x.shape[-2:]) + out = torch.stack(x.chunk(n_heads, dim=2)) # hxBxTxC/hxHxW + out = attn[:, :, :, None, :, :] * out + out = out.sum(dim=2) # sum on temporal dim -> hxBxC/hxHxW + out = torch.cat([group for group in out], dim=1) # -> BxCxHxW + return out + elif self.mode == "att_mean": + attn = attn_mask.mean(dim=0) # average over heads -> BxTxHxW + attn = nn.Upsample( + size=x.shape[-2:], mode="bilinear", align_corners=False + )(attn) + # this got moved out of ScaledDotProductAttention, apply after upsampling + attn = self.attn_dropout(attn) + out = (x * attn[:, :, None, :, :]).sum(dim=1) + return out + elif self.mode == "mean": + return x.mean(dim=1) + +def get_nonlinearity(mode, eps): + if mode=='relu': fct = nn.ReLU() + elif mode=='softplus': fct = lambda vars:nn.Softplus(beta=1, threshold=20)(vars) + eps + elif mode=='elu': fct = lambda vars: nn.ELU()(vars) + 1 + eps + else: fct = nn.Identity() + return fct + +# class UNCRTAINTS(nn.Module): +# def __init__( +# self, +# input_dim, +# encoder_widths=[128], +# decoder_widths=[128,128,128,128,128], +# out_conv=[S2_BANDS], +# out_nonlin_mean=False, +# out_nonlin_var='relu', +# agg_mode="att_group", +# encoder_norm="group", +# decoder_norm="batch", +# n_head=16, +# d_model=256, +# d_k=4, +# pad_value=0, +# padding_mode="reflect", +# positional_encoding=True, +# covmode='diag', +# scale_by=1, +# separate_out=False, +# use_v=False, +# block_type='mbconv', +# is_mono=False +# ): +# """ +# UnCRtainTS architecture for spatio-temporal encoding of satellite image time series. +# Args: +# input_dim (int): Number of channels in the input images. +# encoder_widths (List[int]): List giving the number of channels of the successive encoder_widths of the convolutional encoder. +# This argument also defines the number of encoder_widths (i.e. the number of downsampling steps +1) +# in the architecture. +# The number of channels are given from top to bottom, i.e. from the highest to the lowest resolution. +# decoder_widths (List[int], optional): Same as encoder_widths but for the decoder. The order in which the number of +# channels should be given is also from top to bottom. If this argument is not specified the decoder +# will have the same configuration as the encoder. +# out_conv (List[int]): Number of channels of the successive convolutions for the +# agg_mode (str): Aggregation mode for the skip connections. Can either be: +# - att_group (default) : Attention weighted temporal average, using the same +# channel grouping strategy as in the LTAE. The attention masks are bilinearly +# resampled to the resolution of the skipped feature maps. +# - att_mean : Attention weighted temporal average, +# using the average attention scores across heads for each date. +# - mean : Temporal average excluding padded dates. +# encoder_norm (str): Type of normalisation layer to use in the encoding branch. Can either be: +# - group : GroupNorm (default) +# - batch : BatchNorm +# - instance : InstanceNorm +# - none: apply no normalization +# decoder_norm (str): similar to encoder_norm +# n_head (int): Number of heads in LTAE. +# d_model (int): Parameter of LTAE +# d_k (int): Key-Query space dimension +# pad_value (float): Value used by the dataloader for temporal padding. +# padding_mode (str): Spatial padding strategy for convolutional layers (passed to nn.Conv2d). +# positional_encoding (bool): If False, no positional encoding is used (default True). +# """ +# super(UNCRTAINTS, self).__init__() +# self.n_stages = len(encoder_widths) +# self.encoder_widths = encoder_widths +# self.decoder_widths = decoder_widths +# self.out_widths = out_conv +# self.is_mono = is_mono +# self.use_v = use_v +# self.block_type = block_type + +# self.enc_dim = decoder_widths[0] if decoder_widths is not None else encoder_widths[0] +# self.stack_dim = sum(decoder_widths) if decoder_widths is not None else sum(encoder_widths) +# self.pad_value = pad_value +# self.padding_mode = padding_mode + +# self.scale_by = scale_by +# self.separate_out = separate_out # define two separate layer streams for mean and variance predictions + +# if decoder_widths is not None: +# assert encoder_widths[-1] == decoder_widths[-1] +# else: decoder_widths = encoder_widths + + +# # ENCODER +# self.in_conv = ConvBlock( +# nkernels=[input_dim] + [encoder_widths[0]], +# k=1, s=1, p=0, +# norm=encoder_norm, +# ) + +# if self.block_type=='mbconv': +# self.in_block = nn.ModuleList([MBConv(layer, layer, downsample=False, expansion=2, norm=encoder_norm) for layer in encoder_widths]) +# elif self.block_type=='residual': +# self.in_block = nn.ModuleList([ResidualConvBlock(nkernels=[layer]+[layer], k=3, s=1, p=1, norm=encoder_norm, n_groups=4) for layer in encoder_widths]) +# else: raise NotImplementedError + +# if not self.is_mono: +# # LTAE +# if self.use_v: +# # same as standard LTAE, except we don't apply dropout on the low-resolution attention masks +# self.temporal_encoder = LTAE2d( +# in_channels=encoder_widths[0], +# d_model=d_model, +# n_head=n_head, +# mlp=[d_model, encoder_widths[0]], # MLP to map v, only used if self.use_v=True +# return_att=True, +# d_k=d_k, +# positional_encoding=positional_encoding, +# use_dropout=False +# ) +# # linearly combine mask-weighted +# v_dim = encoder_widths[0] +# self.include_v = nn.Conv2d(encoder_widths[0]+v_dim, encoder_widths[0], 1) +# else: +# self.temporal_encoder = LTAE2dtiny( +# in_channels=encoder_widths[0], +# d_model=d_model, +# n_head=n_head, +# d_k=d_k, +# positional_encoding=positional_encoding, +# ) + +# self.temporal_aggregator = Compact_Temporal_Aggregator(mode=agg_mode) + +# if self.block_type=='mbconv': +# self.out_block = nn.ModuleList([MBConv(layer, layer, downsample=False, expansion=2, norm=decoder_norm) for layer in decoder_widths]) +# elif self.block_type=='residual': +# self.out_block = nn.ModuleList([ResidualConvBlock(nkernels=[layer]+[layer], k=3, s=1, p=1, norm=decoder_norm, n_groups=4) for layer in decoder_widths]) +# else: raise NotImplementedError + + +# self.covmode = covmode +# if covmode=='uni': +# # batching across channel dimension +# covar_dim = S2_BANDS +# elif covmode=='iso': +# covar_dim = 1 +# elif covmode=='diag': +# covar_dim = S2_BANDS +# else: covar_dim = 0 + +# self.mean_idx = S2_BANDS +# self.vars_idx = self.mean_idx + covar_dim + +# # note: not including normalization layer and ReLU nonlinearity into the final ConvBlock +# # if inserting >1 layers into out_conv then consider treating normalizations separately +# self.out_dims = out_conv[-1] + +# eps = 1e-9 if self.scale_by==1.0 else 1e-3 + +# if self.separate_out: # define two separate layer streams for mean and variance predictions +# self.out_conv_mean_1 = ConvBlock(nkernels=[decoder_widths[0]] + [S2_BANDS], k=1, s=1, p=0, norm='none', last_relu=False) +# if self.out_dims - self.mean_idx > 0: +# self.out_conv_var_1 = ConvBlock(nkernels=[decoder_widths[0]] + [self.out_dims - S2_BANDS], k=1, s=1, p=0, norm='none', last_relu=False) +# else: +# self.out_conv = ConvBlock(nkernels=[decoder_widths[0]] + out_conv, k=1, s=1, p=0, norm='none', last_relu=False) + +# # set output nonlinearities +# if out_nonlin_mean: self.out_mean = lambda vars: self.scale_by * nn.Sigmoid()(vars) # this is for predicting mean values in [0, 1] +# else: self.out_mean = nn.Identity() # just keep the mean estimates, without applying a nonlinearity + +# if self.covmode in ['uni', 'iso', 'diag']: +# self.diag_var = get_nonlinearity(out_nonlin_var, eps) + + +# def forward(self, input, batch_positions=None): +# print(input.shape) +# pad_mask = ( +# (input == self.pad_value).all(dim=-1).all(dim=-1).all(dim=-1) +# ) # BxT pad mask +# # SPATIAL ENCODER +# # collect feature maps in list 'feature_maps' +# out = self.in_conv.smart_forward(input) + +# for layer in self.in_block: +# out = layer.smart_forward(out) + +# if not self.is_mono: +# att_down = 32 +# down = nn.AdaptiveMaxPool2d((att_down, att_down))(out.view(out.shape[0] * out.shape[1], *out.shape[2:])).view(out.shape[0], out.shape[1], out.shape[2], att_down, att_down) + +# # TEMPORAL ENCODER +# if self.use_v: +# v, att = self.temporal_encoder(down, batch_positions=batch_positions, pad_mask=pad_mask) +# else: +# att = self.temporal_encoder(down, batch_positions=batch_positions, pad_mask=pad_mask) + +# out = self.temporal_aggregator(out, pad_mask=pad_mask, attn_mask=att) + +# if self.use_v: +# # upsample values to input resolution, then linearly combine with attention masks +# up_v = nn.Upsample(size=(out.shape[-2:]), mode="bilinear", align_corners=False)(v) +# out = self.include_v(torch.cat((out, up_v), dim=1)) +# else: out = out.squeeze(dim=1) + +# # SPATIAL DECODER +# for layer in self.out_block: +# out = layer.smart_forward(out) + +# if self.separate_out: +# out_mean_1 = self.out_conv_mean_1(out) + +# if self.out_dims - self.mean_idx > 0: +# out_var_1 = self.out_conv_var_1(out) +# out = torch.cat((out_mean_1, out_var_1), dim=1) +# else: out = out_mean_1 #out = out_mean_2 +# else: +# out = self.out_conv(out) # predict mean and var in single layer + + +# # append a singelton temporal dimension such that outputs are [B x T=1 x C x H x W] +# out = out.unsqueeze(dim=1) + +# # apply output nonlinearities + +# # get mean predictions +# out_loc = self.out_mean(out[:,:,:self.mean_idx,...]) # mean predictions in [0,1] +# if not self.covmode: return out_loc + +# out_cov = self.diag_var(out[:,:,self.mean_idx:self.vars_idx,...]) # var predictions > 0 +# out = torch.cat((out_loc, out_cov), dim=2) # stack mean and var predictions plus cloud masks +# print(f"{out.shape}") +# return out + +import torch +import torch.nn as nn +import torch.nn.functional as F +from functools import partial +import math +from abc import abstractmethod + + +class EmbedBlock(nn.Module): + """ + Any module where forward() takes embeddings as a second argument. + """ + + @abstractmethod + def forward(self, x, emb): + """ + Apply the module to `x` given `emb` embeddings. + """ + + +class EmbedSequential(nn.Sequential, EmbedBlock): + """ + A sequential module that passes embeddings to the children that + support it as an extra input. + """ + + def forward(self, x, emb): + for layer in self: + if isinstance(layer, EmbedBlock): + x = layer(x, emb) + else: + x = layer(x) + return x + + +def gamma_embedding(gammas, dim, max_period=10000): + """ + Create sinusoidal timestep embeddings. + :param gammas: a 1-D Tensor of N indices, one per batch element. + These may be fractional. + :param dim: the dimension of the output. + :param max_period: controls the minimum frequency of the embeddings. + :return: an [N x dim] Tensor of positional embeddings. + """ + half = dim // 2 + freqs = torch.exp( + -math.log(max_period) * torch.arange(start=0, + end=half, dtype=torch.float32) / half + ).to(device=gammas.device) + args = gammas[:, None].float() * freqs[None] + embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1) + if dim % 2: + embedding = torch.cat( + [embedding, torch.zeros_like(embedding[:, :1])], dim=-1) + return embedding + + +class LayerNormFunction(torch.autograd.Function): + + @staticmethod + def forward(ctx, x, weight, bias, eps): + ctx.eps = eps + N, C, H, W = x.size() + mu = x.mean(1, keepdim=True) + var = (x - mu).pow(2).mean(1, keepdim=True) + y = (x - mu) / (var + eps).sqrt() + ctx.save_for_backward(y, var, weight) + y = weight.view(1, C, 1, 1) * y + bias.view(1, C, 1, 1) + return y + + @staticmethod + def backward(ctx, grad_output): + eps = ctx.eps + + N, C, H, W = grad_output.size() + y, var, weight = ctx.saved_variables + g = grad_output * weight.view(1, C, 1, 1) + mean_g = g.mean(dim=1, keepdim=True) + + mean_gy = (g * y).mean(dim=1, keepdim=True) + gx = 1. / torch.sqrt(var + eps) * (g - y * mean_gy - mean_g) + return gx, (grad_output * y).sum(dim=3).sum(dim=2).sum(dim=0), grad_output.sum(dim=3).sum(dim=2).sum( + dim=0), None + + +class LayerNorm2d(nn.Module): + + def __init__(self, channels, eps=1e-6): + super(LayerNorm2d, self).__init__() + self.register_parameter('weight', nn.Parameter(torch.ones(channels))) + self.register_parameter('bias', nn.Parameter(torch.zeros(channels))) + self.eps = eps + + def forward(self, x): + return LayerNormFunction.apply(x, self.weight, self.bias, self.eps) + + +class SimpleGate(nn.Module): + def forward(self, x): + x1, x2 = x.chunk(2, dim=1) + return x1 * x2 + + +class CondNAFBlock(nn.Module): + def __init__(self, c, DW_Expand=2, FFN_Expand=2, drop_out_rate=0.): + super().__init__() + dw_channel = c * DW_Expand + self.conv1 = nn.Conv2d(in_channels=c, out_channels=dw_channel, + kernel_size=1, padding=0, stride=1, groups=1, bias=True) + self.conv2 = nn.Conv2d(in_channels=dw_channel, out_channels=dw_channel, kernel_size=3, padding=1, stride=1, groups=dw_channel, + bias=True) + self.conv3 = nn.Conv2d(in_channels=dw_channel // 2, out_channels=c, + kernel_size=1, padding=0, stride=1, groups=1, bias=True) + + # Simplified Channel Attention + # self.sca = nn.Sequential( + # nn.AdaptiveAvgPool2d(1), + # nn.Conv2d(in_channels=dw_channel // 2, out_channels=dw_channel // 2, kernel_size=1, padding=0, stride=1, + # groups=1, bias=True), + # ) + self.sca_avg = nn.Sequential( + nn.AdaptiveAvgPool2d(1), + nn.Conv2d(in_channels=dw_channel // 4, out_channels=dw_channel // 4, kernel_size=1, padding=0, stride=1, + groups=1, bias=True), + ) + self.sca_max = nn.Sequential( + nn.AdaptiveMaxPool2d(1), + nn.Conv2d(in_channels=dw_channel // 4, out_channels=dw_channel // 4, kernel_size=1, padding=0, stride=1, + groups=1, bias=True), + ) + + # SimpleGate + self.sg = SimpleGate() + + ffn_channel = FFN_Expand * c + self.conv4 = nn.Conv2d(in_channels=c, out_channels=ffn_channel, + kernel_size=1, padding=0, stride=1, groups=1, bias=True) + self.conv5 = nn.Conv2d(in_channels=ffn_channel // 2, out_channels=c, + kernel_size=1, padding=0, stride=1, groups=1, bias=True) + + self.norm1 = LayerNorm2d(c) + self.norm2 = LayerNorm2d(c) + + self.dropout1 = nn.Dropout( + drop_out_rate) if drop_out_rate > 0. else nn.Identity() + self.dropout2 = nn.Dropout( + drop_out_rate) if drop_out_rate > 0. else nn.Identity() + + self.beta = nn.Parameter(torch.zeros((1, c, 1, 1)), requires_grad=True) + self.gamma = nn.Parameter(torch.zeros( + (1, c, 1, 1)), requires_grad=True) + + def forward(self, inp): + x = inp + + x = self.norm1(x) + + x = self.conv1(x) + x = self.conv2(x) + x = self.sg(x) + x_avg, x_max = x.chunk(2, dim=1) + x_avg = self.sca_avg(x_avg)*x_avg + x_max = self.sca_max(x_max)*x_max + x = torch.cat([x_avg, x_max], dim=1) + x = self.conv3(x) + + x = self.dropout1(x) + + y = inp + x * self.beta + + x = self.conv4(self.norm2(y)) + x = self.sg(x) + x = self.conv5(x) + + x = self.dropout2(x) + + return y + x * self.gamma + + +class NAFBlock(nn.Module): + def __init__(self, c, DW_Expand=2, FFN_Expand=2, drop_out_rate=0.): + super().__init__() + dw_channel = c * DW_Expand + self.conv1 = nn.Conv2d(in_channels=c, out_channels=dw_channel, + kernel_size=1, padding=0, stride=1, groups=1, bias=True) + self.conv2 = nn.Conv2d(in_channels=dw_channel, out_channels=dw_channel, kernel_size=3, padding=1, stride=1, groups=dw_channel, + bias=True) + self.conv3 = nn.Conv2d(in_channels=dw_channel // 2, out_channels=c, + kernel_size=1, padding=0, stride=1, groups=1, bias=True) + + # Simplified Channel Attention + # self.sca = nn.Sequential( + # nn.AdaptiveAvgPool2d(1), + # nn.Conv2d(in_channels=dw_channel // 2, out_channels=dw_channel // 2, kernel_size=1, padding=0, stride=1, + # groups=1, bias=True), + # ) + self.sca_avg = nn.Sequential( + nn.AdaptiveAvgPool2d(1), + nn.Conv2d(in_channels=dw_channel // 4, out_channels=dw_channel // 4, kernel_size=1, padding=0, stride=1, + groups=1, bias=True), + ) + self.sca_max = nn.Sequential( + nn.AdaptiveMaxPool2d(1), + nn.Conv2d(in_channels=dw_channel // 4, out_channels=dw_channel // 4, kernel_size=1, padding=0, stride=1, + groups=1, bias=True), + ) + + # SimpleGate + self.sg = SimpleGate() + + ffn_channel = FFN_Expand * c + self.conv4 = nn.Conv2d(in_channels=c, out_channels=ffn_channel, + kernel_size=1, padding=0, stride=1, groups=1, bias=True) + self.conv5 = nn.Conv2d(in_channels=ffn_channel // 2, out_channels=c, + kernel_size=1, padding=0, stride=1, groups=1, bias=True) + + self.norm1 = LayerNorm2d(c) + self.norm2 = LayerNorm2d(c) + + self.dropout1 = nn.Dropout( + drop_out_rate) if drop_out_rate > 0. else nn.Identity() + self.dropout2 = nn.Dropout( + drop_out_rate) if drop_out_rate > 0. else nn.Identity() + + self.beta = nn.Parameter(torch.zeros((1, c, 1, 1)), requires_grad=True) + self.gamma = nn.Parameter(torch.zeros( + (1, c, 1, 1)), requires_grad=True) + # self.time_emb = nn.Sequential( + # nn.SiLU(), + # nn.Linear(256, c), + # ) + + def forward(self, inp): + x = inp + + x = self.norm1(x) + + x = self.conv1(x) + x = self.conv2(x) + x = self.sg(x) + x_avg, x_max = x.chunk(2, dim=1) + x_avg = self.sca_avg(x_avg)*x_avg + x_max = self.sca_max(x_max)*x_max + x = torch.cat([x_avg, x_max], dim=1) + x = self.conv3(x) + + x = self.dropout1(x) + + y = inp + x * self.beta + + # y = y+self.time_emb(t)[..., None, None] + + x = self.conv4(self.norm2(y)) + x = self.sg(x) + x = self.conv5(x) + + x = self.dropout2(x) + + return y + x * self.gamma + + +class UNCRTAINTS(nn.Module): + + def __init__( + self, + input_dim=15, + out_conv=[13], + width=64, + middle_blk_num=1, + enc_blk_nums=[1, 1, 1, 1], + dec_blk_nums=[1, 1, 1, 1], + encoder_widths=[128], + decoder_widths=[128,128,128,128,128], + out_nonlin_mean=False, + out_nonlin_var='relu', + agg_mode="att_group", + encoder_norm="group", + decoder_norm="batch", + n_head=16, + d_model=256, + d_k=4, + pad_value=0, + padding_mode="reflect", + positional_encoding=True, + covmode='diag', + scale_by=1, + separate_out=False, + use_v=False, + block_type='mbconv', + is_mono=False + ): + super().__init__() + + self.intro = nn.Conv2d(in_channels=input_dim, out_channels=width, kernel_size=3, padding=1, stride=1, groups=1, + bias=True) + # self.cond_intro = nn.Conv2d(in_channels=img_channel+2, out_channels=width, kernel_size=3, padding=1, stride=1, groups=1, + # bias=True) + self.ending = nn.Conv2d(in_channels=width, out_channels=out_conv[0], kernel_size=3, padding=1, stride=1, groups=1, + bias=True) + # self.inp_ending = nn.Conv2d(in_channels=img_channel, out_channels=3, kernel_size=3, padding=1, stride=1, groups=1, + # bias=True) + + self.encoders = nn.ModuleList() + self.cond_encoders = nn.ModuleList() + + self.decoders = nn.ModuleList() + + self.middle_blks = nn.ModuleList() + + self.ups = nn.ModuleList() + + self.downs = nn.ModuleList() + self.cond_downs = nn.ModuleList() + + chan = width + for num in enc_blk_nums: + self.encoders.append( + nn.Sequential( + *[NAFBlock(chan) for _ in range(num)] + ) + ) + self.cond_encoders.append( + nn.Sequential( + *[CondNAFBlock(chan) for _ in range(num)] + ) + ) + self.downs.append( + nn.Conv2d(chan, 2*chan, 2, 2) + ) + # self.cond_downs.append( + # nn.Conv2d(chan, 2*chan, 2, 2) + # ) + chan = chan * 2 + + self.middle_blks = \ + nn.Sequential( + *[NAFBlock(chan) for _ in range(middle_blk_num)] + ) + + for num in dec_blk_nums: + self.ups.append( + nn.Sequential( + nn.Conv2d(chan, chan * 2, 1, bias=False), + nn.PixelShuffle(2) + ) + ) + chan = chan // 2 + self.decoders.append( + nn.Sequential( + *[NAFBlock(chan) for _ in range(num)] + ) + ) + + self.padder_size = 2 ** len(self.encoders) + # self.map = nn.Sequential( + # nn.Linear(64, 256), + # nn.SiLU(), + # nn.Linear(256, 256), + # ) + + def forward(self, inp, batch_positions): + # inp = self.check_image_size(inp) + inp = inp.squeeze(1) + x = self.intro(inp) + + encs = [] + + for encoder, down in zip(self.encoders, self.downs): + x = encoder(x) + # b, c, h, w = cond.shape + # tmp_cond = cond.view(b//3, 3, c, h, w).sum(dim=1) + # tmp_cond = cond + # x = x + tmp_cond + encs.append(x) + x = down(x) + # cond = cond_down(cond) + + x = self.middle_blks(x) + + for decoder, up, enc_skip in zip(self.decoders, self.ups, encs[::-1]): + x = up(x) + x = x + enc_skip + x = decoder(x) + + x = self.ending(x) + # x = x + self.inp_ending(inp) + # print(x.shape) + return x.unsqueeze(1) + + def check_image_size(self, x): + _, _, h, w = x.size() + mod_pad_h = (self.padder_size - h % + self.padder_size) % self.padder_size + mod_pad_w = (self.padder_size - w % + self.padder_size) % self.padder_size + x = F.pad(x, (0, mod_pad_w, 0, mod_pad_h)) + return x + + +if __name__ == '__main__': + # unit test for ground resolution + inp = torch.randn(1, 15, 256, 256) + net = UNCRTAINTS( + input_dim=15, + out_conv=[13], + width=64, + middle_blk_num=1, + enc_blk_nums=[1, 1, 1, 1], + dec_blk_nums=[1, 1, 1, 1], + ) + out = net(inp) + assert out.shape == (1, 13, 256, 256) + + # from thop import profile + # out_shape = (1, 12, 384, 384) + # input_shape = (1, 13, 384, 384) + # model = DiffCR( + # img_channel=13, + # width=32, + # middle_blk_num=1, + # enc_blk_nums=[1, 1, 1, 1], + # dec_blk_nums=[1, 1, 1, 1], + # ) + # # 使用 thop 的 profile 函数来获取 FLOPs 和参数量 + # flops, params = profile(model, inputs=(torch.randn(out_shape), torch.ones(1,), torch.randn(input_shape))) + # print(f"FLOPs: {flops / 1e9} G") + # print(f"Parameters: {params / 1e6} M") + + + +# if __name__=='__main__': +# inp = torch.rand(1, 15, 256, 256) +# net = UNCRTAINTS( +# input_dim=15, +# out_conv=[13], +# ) +# out = net(inp) +# assert out.shape==(1, 13, 256, 256) \ No newline at end of file diff --git a/UnCRtainTS/model/src/backbones/unet3d.py b/UnCRtainTS/model/src/backbones/unet3d.py new file mode 100644 index 0000000000000000000000000000000000000000..5f95aff5564d729adec004dbe9c018e5a4ca233b --- /dev/null +++ b/UnCRtainTS/model/src/backbones/unet3d.py @@ -0,0 +1,120 @@ +""" +Taken from https://github.com/roserustowicz/crop-type-mapping/ +Implementation by the authors of the paper : +"Semantic Segmentation of crop type in Africa: A novel Dataset and analysis of deep learning methods" +R.M. Rustowicz et al. + +Slightly modified to support image sequences of varying length in the same batch. +""" + +import torch +import torch.nn as nn + + +def conv_block(in_dim, middle_dim, out_dim): + model = nn.Sequential( + nn.Conv3d(in_dim, middle_dim, kernel_size=3, stride=1, padding=1), + nn.BatchNorm3d(middle_dim), + nn.LeakyReLU(inplace=True), + nn.Conv3d(middle_dim, out_dim, kernel_size=3, stride=1, padding=1), + nn.BatchNorm3d(out_dim), + nn.LeakyReLU(inplace=True), + ) + return model + + +def center_in(in_dim, out_dim): + model = nn.Sequential( + nn.Conv3d(in_dim, out_dim, kernel_size=3, stride=1, padding=1), + nn.BatchNorm3d(out_dim), + nn.LeakyReLU(inplace=True)) + return model + + +def center_out(in_dim, out_dim): + model = nn.Sequential( + nn.Conv3d(in_dim, in_dim, kernel_size=3, stride=1, padding=1), + nn.BatchNorm3d(in_dim), + nn.LeakyReLU(inplace=True), + nn.ConvTranspose3d(in_dim, out_dim, kernel_size=3, stride=2, padding=1, output_padding=1)) + return model + + +def up_conv_block(in_dim, out_dim): + model = nn.Sequential( + nn.ConvTranspose3d(in_dim, out_dim, kernel_size=3, stride=2, padding=1, output_padding=1), + nn.BatchNorm3d(out_dim), + nn.LeakyReLU(inplace=True), + ) + return model + + +class UNet3D(nn.Module): + def __init__(self, in_channel, n_classes, feats=8, pad_value=None, zero_pad=True, out_nonlin=False): + super(UNet3D, self).__init__() + self.in_channel = in_channel + self.n_classes = n_classes + self.pad_value = pad_value + self.zero_pad = zero_pad + + self.en3 = conv_block(in_channel, feats * 4, feats * 4) + self.pool_3 = nn.MaxPool3d(kernel_size=2, stride=2, padding=0) + self.en4 = conv_block(feats * 4, feats * 8, feats * 8) + self.pool_4 = nn.MaxPool3d(kernel_size=2, stride=2, padding=0) + self.center_in = center_in(feats * 8, feats * 16) + self.center_out = center_out(feats * 16, feats * 8) + self.dc4 = conv_block(feats * 16, feats * 8, feats * 8) + self.trans3 = up_conv_block(feats * 8, feats * 4) + self.dc3 = conv_block(feats * 8, feats * 4, feats * 2) + self.final = nn.Conv3d(feats * 2, n_classes, kernel_size=3, stride=1, padding=1) + if out_nonlin: + self.out_sigm = nn.Sigmoid() # this is for predicting mean values in [0, 1] + self.out_relu = nn.ReLU() # this is for predicting var values > 0 + # self.fn = nn.Linear(timesteps, 1) + # self.logsoftmax = nn.LogSoftmax(dim=1) + # self.dropout = nn.Dropout(p=dropout, inplace=True) + + def forward(self, x, batch_positions=None): + out = x.permute(0, 2, 1, 3, 4) # x was BxTxCxHxW, now BxCxTxHxW + if self.pad_value is not None: + pad_mask = (out == self.pad_value).all(dim=-1).all(dim=-1).all(dim=1) # BxT pad mask + if self.zero_pad: + out[out == self.pad_value] = 0 + en3 = self.en3(out) + pool_3 = self.pool_3(en3) + en4 = self.en4(pool_3) + pool_4 = self.pool_4(en4) + center_in = self.center_in(pool_4) + center_out = self.center_out(center_in) + concat4 = torch.cat([center_out, en4[:, :, :center_out.shape[2], :, :]], dim=1) + dc4 = self.dc4(concat4) + trans3 = self.trans3(dc4) + concat3 = torch.cat([trans3, en3[:, :, :trans3.shape[2], :, :]], dim=1) + dc3 = self.dc3(concat3) + final = self.final(dc3) + final = final.permute(0, 1, 3, 4, 2) # BxCxHxWxT + + # shape_num = final.shape[0:4] + # final = final.reshape(-1,final.shape[4]) + if self.pad_value is not None: + if pad_mask.any(): + # masked mean + pad_mask = pad_mask[:, :final.shape[-1]] #match new temporal length (due to pooling) + pad_mask = ~pad_mask # 0 on padded values + out = (final.permute(1, 2, 3, 0, 4) * pad_mask[None, None, None, :, :]).sum(dim=-1) / pad_mask.sum( + dim=-1)[None, None, None, :] + out = out.permute(3, 0, 1, 2) + else: + out = final.mean(dim=-1) + else: + out = final.mean(dim=-1) + if hasattr(self, 'out_sigm'): + out_mean = self.out_sigm(out[:,:,:13,...]) # mean predictions + out_std = self.out_relu(out[:,:,13:,...]) # var predictions + # stack mean and var predictions + out = torch.cat((out_mean, out_std), dim=2) + # final = self.dropout(final) + # final = self.fn(final) + # final = final.reshape(shape_num) + + return out diff --git a/UnCRtainTS/model/src/backbones/utae.py b/UnCRtainTS/model/src/backbones/utae.py new file mode 100644 index 0000000000000000000000000000000000000000..5082bc59d2bf381fd14941473c8b9f4a1d889ad6 --- /dev/null +++ b/UnCRtainTS/model/src/backbones/utae.py @@ -0,0 +1,852 @@ +""" +U-TAE Implementation +Author: Vivien Sainte Fare Garnot (github/VSainteuf) +License: MIT +""" + +import torch +import torch.nn as nn + +from src.backbones.convlstm import ConvLSTM, BConvLSTM +from src.backbones.ltae import LTAE2d, LTAE2dtiny + +# function to normalize gradient magnitudes, +# evoke via e.g. scale_gradients(out) at every forward pass +def scale_gradients(params): + def hook_norm(grad): + # get norm of parameter p's gradients + #grad_norm = p.grad.detach().data.norm(2) + # get the gradient's L2 norm + grad_norm = grad.detach().data.norm(2) + # return normalized gradient + return grad/(grad_norm+1e-9) + # see https://pytorch.org/docs/stable/generated/torch.Tensor.register_hook.html + params.register_hook(hook_norm) + + +class UNet(nn.Module): + def __init__( + self, + input_dim, + encoder_widths=[64, 64, 64, 128], + decoder_widths=[32, 32, 64, 128], + out_conv=[13], + out_nonlin_mean=False, + out_nonlin_var='relu', + str_conv_k=4, + str_conv_s=2, + str_conv_p=1, + encoder_norm="group", + norm_skip="batch", + norm_up="batch", + decoder_norm="batch", + encoder=False, + return_maps=False, + pad_value=0, + padding_mode="reflect", + ): + """ + U-Net architecture for spatial pre-training of UTAE on mono-temporal data, excluding LTAE temporal encoder. + Args: + input_dim (int): Number of channels in the input images. + encoder_widths (List[int]): List giving the number of channels of the successive encoder_widths of the convolutional encoder. + This argument also defines the number of encoder_widths (i.e. the number of downsampling steps +1) + in the architecture. + The number of channels are given from top to bottom, i.e. from the highest to the lowest resolution. + decoder_widths (List[int], optional): Same as encoder_widths but for the decoder. The order in which the number of + channels should be given is also from top to bottom. If this argument is not specified the decoder + will have the same configuration as the encoder. + out_conv (List[int]): Number of channels of the successive convolutions for the + str_conv_k (int): Kernel size of the strided up and down convolutions. + str_conv_s (int): Stride of the strided up and down convolutions. + str_conv_p (int): Padding of the strided up and down convolutions. + agg_mode (str): Aggregation mode for the skip connections. Can either be: + - att_group (default) : Attention weighted temporal average, using the same + channel grouping strategy as in the LTAE. The attention masks are bilinearly + resampled to the resolution of the skipped feature maps. + - att_mean : Attention weighted temporal average, + using the average attention scores across heads for each date. + - mean : Temporal average excluding padded dates. + encoder_norm (str): Type of normalisation layer to use in the encoding branch. Can either be: + - group : GroupNorm (default) + - batch : BatchNorm + - instance : InstanceNorm + - none: apply no normalization + norm_skip (str): similar to encoder_norm, just controlling the normalization after convolving skipped maps + norm_up (str): similar to encoder_norm, just controlling the normalization after transposed convolution + decoder_norm (str): similar to encoder_norm + n_head (int): Number of heads in LTAE. + d_model (int): Parameter of LTAE + d_k (int): Key-Query space dimension + encoder (bool): If true, the feature maps instead of the class scores are returned (default False) + return_maps (bool): If true, the feature maps instead of the class scores are returned (default False) + pad_value (float): Value used by the dataloader for temporal padding. + padding_mode (str): Spatial padding strategy for convolutional layers (passed to nn.Conv2d). + positional_encoding (bool): If False, no positional encoding is used (default True). + """ + super(UNet, self).__init__() + self.n_stages = len(encoder_widths) + self.return_maps = return_maps + self.encoder_widths = encoder_widths + self.decoder_widths = decoder_widths + self.enc_dim = ( + decoder_widths[0] if decoder_widths is not None else encoder_widths[0] + ) + self.stack_dim = ( + sum(decoder_widths) if decoder_widths is not None else sum(encoder_widths) + ) + self.pad_value = pad_value + self.encoder = encoder + if encoder: + self.return_maps = True + + if decoder_widths is not None: + assert len(encoder_widths) == len(decoder_widths) + assert encoder_widths[-1] == decoder_widths[-1] + else: + decoder_widths = encoder_widths + + # ENCODER + self.in_conv = ConvBlock( + nkernels=[input_dim] + [encoder_widths[0]], + k=1, s=1, p=0, + pad_value=pad_value, + norm=encoder_norm, + padding_mode=padding_mode, + ) + self.down_blocks = nn.ModuleList( + DownConvBlock( + d_in=encoder_widths[i], + d_out=encoder_widths[i + 1], + k=str_conv_k, + s=str_conv_s, + p=str_conv_p, + pad_value=pad_value, + norm=encoder_norm, + padding_mode=padding_mode, + ) + for i in range(self.n_stages - 1) + ) + # DECODER + self.up_blocks = nn.ModuleList( + UpConvBlock( + d_in=decoder_widths[i], + d_out=decoder_widths[i - 1], + d_skip=encoder_widths[i - 1], + k=str_conv_k, + s=str_conv_s, + p=str_conv_p, + norm_skip=norm_skip, #'batch' + norm_up=norm_up, # 'batch' + norm=decoder_norm, #"batch", + padding_mode=padding_mode, + ) + for i in range(self.n_stages - 1, 0, -1) + ) + # note: not including normalization layer and ReLU nonlinearity into the final ConvBlock, + # if inserting >1 layers into out_conv then consider treating normalizations separately + self.out_dims = out_conv[-1] + self.out_conv = ConvBlock(nkernels=[decoder_widths[0]] + out_conv, k=1, s=1, p=0, padding_mode=padding_mode, norm='none', last_relu=False) + + if out_nonlin_mean: + self.out_mean = nn.Sigmoid() # this is for predicting mean values in [0, 1] + else: + self.out_mean = nn.Identity() # just keep the mean estimates, without applying a nonlinearity + + if out_nonlin_var=='relu': + self.out_var = nn.ReLU() # this is for predicting var values > 0 + elif out_nonlin_var=='softplus': + self.out_var = nn.Softplus(beta=1, threshold=20) # a smooth approximation to the ReLU function + elif out_nonlin_var=='elu': + self.out_var = lambda vars: nn.ELU()(vars) + 1 + 1e-8 + else: # just keep the variance estimates, + self.out_var = nn.Identity() # just keep the variance estimates, without applying a nonlinearity + + def forward(self, input, batch_positions=None, return_att=False): + # SPATIAL ENCODER + # collect feature maps in list 'feature_maps' + out = self.in_conv.smart_forward(input) + feature_maps = [out] + for i in range(self.n_stages - 1): + out = self.down_blocks[i].smart_forward(feature_maps[-1]) + feature_maps.append(out) + # SPATIAL DECODER + if self.return_maps: + maps = [out] + out = out[:,0,...] # note: we index to reduce the temporal dummy dimension of size 1 + for i in range(self.n_stages - 1): + # skip-connect features between paired encoder/decoder blocks + skip = feature_maps[-(i + 2)] + # upconv the features, concatenating current 'out' and paired 'skip' + out = self.up_blocks[i](out, skip[:,0,...]) # note: we index to reduce the temporal dummy dimension of size 1 + if self.return_maps: + maps.append(out) + + if self.encoder: + return out, maps + else: + out = self.out_conv(out) + # append a singelton temporal dimension such that outputs are [B x T=1 x C x H x W] + out = out.unsqueeze(1) + # optionally apply an output nonlinearity + out_mean = self.out_mean(out[:,:,:13,...]) # mean predictions + out_std = self.out_var(out[:,:,13:,...]) # var predictions > 0 + out = torch.cat((out_mean, out_std), dim=2) # stack mean and var predictions + + if return_att: + return out, None + if self.return_maps: + return out, maps + else: + return out + + + +class UTAE(nn.Module): + def __init__( + self, + input_dim, + encoder_widths=[64, 64, 64, 128], + decoder_widths=[32, 32, 64, 128], + out_conv=[13], + out_nonlin_mean=False, + out_nonlin_var='relu', + str_conv_k=4, + str_conv_s=2, + str_conv_p=1, + agg_mode="att_group", + encoder_norm="group", + norm_skip='batch', + norm_up="batch", + decoder_norm="batch", + n_head=16, + d_model=256, + d_k=4, + encoder=False, + return_maps=False, + pad_value=0, + padding_mode="reflect", + positional_encoding=True, + scale_by=1 + ): + """ + U-TAE architecture for spatio-temporal encoding of satellite image time series. + Args: + input_dim (int): Number of channels in the input images. + encoder_widths (List[int]): List giving the number of channels of the successive encoder_widths of the convolutional encoder. + This argument also defines the number of encoder_widths (i.e. the number of downsampling steps +1) + in the architecture. + The number of channels are given from top to bottom, i.e. from the highest to the lowest resolution. + decoder_widths (List[int], optional): Same as encoder_widths but for the decoder. The order in which the number of + channels should be given is also from top to bottom. If this argument is not specified the decoder + will have the same configuration as the encoder. + out_conv (List[int]): Number of channels of the successive convolutions for the + str_conv_k (int): Kernel size of the strided up and down convolutions. + str_conv_s (int): Stride of the strided up and down convolutions. + str_conv_p (int): Padding of the strided up and down convolutions. + agg_mode (str): Aggregation mode for the skip connections. Can either be: + - att_group (default) : Attention weighted temporal average, using the same + channel grouping strategy as in the LTAE. The attention masks are bilinearly + resampled to the resolution of the skipped feature maps. + - att_mean : Attention weighted temporal average, + using the average attention scores across heads for each date. + - mean : Temporal average excluding padded dates. + encoder_norm (str): Type of normalisation layer to use in the encoding branch. Can either be: + - group : GroupNorm (default) + - batch : BatchNorm + - instance : InstanceNorm + - none: apply no normalization + norm_skip (str): similar to encoder_norm, just controlling the normalization after convolving skipped maps + norm_up (str): similar to encoder_norm, just controlling the normalization after transposed convolution + decoder_norm (str): similar to encoder_norm + n_head (int): Number of heads in LTAE. + d_model (int): Parameter of LTAE + d_k (int): Key-Query space dimension + encoder (bool): If true, the feature maps instead of the class scores are returned (default False) + return_maps (bool): If true, the feature maps instead of the class scores are returned (default False) + pad_value (float): Value used by the dataloader for temporal padding. + padding_mode (str): Spatial padding strategy for convolutional layers (passed to nn.Conv2d). + positional_encoding (bool): If False, no positional encoding is used (default True). + """ + super(UTAE, self).__init__() + self.n_stages = len(encoder_widths) + self.return_maps = return_maps + self.encoder_widths = encoder_widths + self.decoder_widths = decoder_widths + self.enc_dim = ( + decoder_widths[0] if decoder_widths is not None else encoder_widths[0] + ) + self.stack_dim = ( + sum(decoder_widths) if decoder_widths is not None else sum(encoder_widths) + ) + self.pad_value = pad_value + self.encoder = encoder + self.scale_by = scale_by + if encoder: + self.return_maps = True + + if decoder_widths is not None: + assert len(encoder_widths) == len(decoder_widths) + assert encoder_widths[-1] == decoder_widths[-1] + else: + decoder_widths = encoder_widths + + # ENCODER + self.in_conv = ConvBlock( + nkernels=[input_dim] + [encoder_widths[0]], + k=1, s=1, p=0, + pad_value=pad_value, + norm=encoder_norm, + padding_mode=padding_mode, + ) + self.down_blocks = nn.ModuleList( + DownConvBlock( + d_in=encoder_widths[i], + d_out=encoder_widths[i + 1], + k=str_conv_k, + s=str_conv_s, + p=str_conv_p, + pad_value=pad_value, + norm=encoder_norm, + padding_mode=padding_mode, + ) + for i in range(self.n_stages - 1) + ) + # DECODER + self.up_blocks = nn.ModuleList( + UpConvBlock( + d_in=decoder_widths[i], + d_out=decoder_widths[i - 1], + d_skip=encoder_widths[i - 1], + k=str_conv_k, + s=str_conv_s, + p=str_conv_p, + norm_skip=norm_skip, # 'batch' + norm_up=norm_up, # 'batch' + norm=decoder_norm, #"batch", + padding_mode=padding_mode, + ) + for i in range(self.n_stages - 1, 0, -1) + ) + # LTAE + self.temporal_encoder = LTAE2d( + in_channels=encoder_widths[-1], + d_model=d_model, + n_head=n_head, + mlp=[d_model, encoder_widths[-1]], + return_att=True, + d_k=d_k, + positional_encoding=positional_encoding, + ) + self.temporal_aggregator = Temporal_Aggregator(mode=agg_mode) + # note: not including normalization layer and ReLU nonlinearity into the final ConvBlock + # if inserting >1 layers into out_conv then consider treating normalizations separately + self.out_dims = out_conv[-1] + self.out_conv = ConvBlock(nkernels=[decoder_widths[0]] + out_conv, k=1, s=1, p=0, padding_mode=padding_mode, norm='none', last_relu=False) + if out_nonlin_mean: + self.out_mean = lambda vars: self.scale_by * nn.Sigmoid()(vars) # this is for predicting mean values in [0, 1] + else: + self.out_mean = lambda vars: nn.Identity()(vars) # just keep the mean estimates, without applying a nonlinearity + + if out_nonlin_var=='relu': + self.out_var = nn.ReLU() # this is for predicting var values > 0 + elif out_nonlin_var=='softplus': + self.out_var = nn.Softplus(beta=1, threshold=20) # a smooth approximation to the ReLU function + elif out_nonlin_var=='elu': + self.out_var = lambda vars: nn.ELU()(vars) + 1 + 1e-8 + else: # just keep the variance estimates, + self.out_var = nn.Identity() # just keep the variance estimates, without applying a nonlinearity + + def forward(self, input, batch_positions=None, return_att=False): + pad_mask = ( + (input == self.pad_value).all(dim=-1).all(dim=-1).all(dim=-1) + ) # BxT pad mask + # SPATIAL ENCODER + # collect feature maps in list 'feature_maps' + out = self.in_conv.smart_forward(input) + feature_maps = [out] + for i in range(self.n_stages - 1): + out = self.down_blocks[i].smart_forward(feature_maps[-1]) + feature_maps.append(out) + # TEMPORAL ENCODER + # feature_maps[-1].shape is torch.Size([B, T, 128, 32, 32]) + # -> every attention pixel has an 8x8 receptive field + # att.shape is torch.Size([h, B, T, 32, 32]) + # out.shape is torch.Size([B, 128, 32, 32]), in self-attention class it's Size([B*32*32*h=32768, 1, 16] + out, att = self.temporal_encoder( + feature_maps[-1], batch_positions=batch_positions, pad_mask=pad_mask + ) + # SPATIAL DECODER + if self.return_maps: + maps = [out] + for i in range(self.n_stages - 1): + skip = self.temporal_aggregator( + feature_maps[-(i + 2)], pad_mask=pad_mask, attn_mask=att + ) + out = self.up_blocks[i](out, skip) + if self.return_maps: + maps.append(out) + + if self.encoder: + return out, maps + else: + out = self.out_conv(out) + # append a singelton temporal dimension such that outputs are [B x T=1 x C x H x W] + out = out.unsqueeze(1) + # optionally apply an output nonlinearity + out_mean = self.out_mean(out[:,:,:13,...]) # mean predictions + out_std = self.out_var(out[:,:,13:,...]) # var predictions > 0 + out = torch.cat((out_mean, out_std), dim=2) # stack mean and var predictions + + if return_att: + return out, att + if self.return_maps: + return out, maps + else: + return out + + +class TemporallySharedBlock(nn.Module): + """ + Helper module for convolutional encoding blocks that are shared across a sequence. + This module adds the self.smart_forward() method the the block. + smart_forward will combine the batch and temporal dimension of an input tensor + if it is 5-D and apply the shared convolutions to all the (batch x temp) positions. + """ + + def __init__(self, pad_value=None): + super(TemporallySharedBlock, self).__init__() + self.out_shape = None + self.pad_value = pad_value + + def smart_forward(self, input): + if len(input.shape) == 4: + return self.forward(input) + else: + b, t, c, h, w = input.shape + + if self.pad_value is not None: + dummy = torch.zeros(input.shape, device=input.device).float() + self.out_shape = self.forward(dummy.view(b * t, c, h, w)).shape + + out = input.view(b * t, c, h, w) + if self.pad_value is not None: + pad_mask = (out == self.pad_value).all(dim=-1).all(dim=-1).all(dim=-1) + if pad_mask.any(): + temp = ( + torch.ones( + self.out_shape, device=input.device, requires_grad=False + ) + * self.pad_value + ) + temp[~pad_mask] = self.forward(out[~pad_mask]) + out = temp + else: + out = self.forward(out) + else: + out = self.forward(out) + _, c, h, w = out.shape + out = out.view(b, t, c, h, w) + return out + + +class ConvLayer(nn.Module): + def __init__( + self, + nkernels, + norm="batch", + k=3, s=1, p=1, + n_groups=4, + last_relu=True, + padding_mode="reflect", + ): + super(ConvLayer, self).__init__() + layers = [] + if norm == "batch": + nl = nn.BatchNorm2d + elif norm == "instance": + nl = nn.InstanceNorm2d + elif norm == "group": + nl = lambda num_feats: nn.GroupNorm( + num_channels=num_feats, + num_groups=n_groups, + ) + else: + nl = None + for i in range(len(nkernels) - 1): + layers.append( + nn.Conv2d( + in_channels=nkernels[i], + out_channels=nkernels[i + 1], + kernel_size=k, + padding=p, + stride=s, + padding_mode=padding_mode, + ) + ) + if nl is not None: + layers.append(nl(nkernels[i + 1])) + + if last_relu: # append a ReLU after the current CONV layer + layers.append(nn.ReLU()) + elif i < len(nkernels) - 2: # only append ReLU if not last layer + layers.append(nn.ReLU()) + self.conv = nn.Sequential(*layers) + + def forward(self, input): + return self.conv(input) + + +class ConvBlock(TemporallySharedBlock): + def __init__( + self, + nkernels, + pad_value=None, + norm="batch", + last_relu=True, + k=3, s=1, p=1, + padding_mode="reflect", + ): + super(ConvBlock, self).__init__(pad_value=pad_value) + self.conv = ConvLayer( + nkernels=nkernels, + norm=norm, + last_relu=last_relu, + k=k, s=s, p=p, + padding_mode=padding_mode, + ) + + def forward(self, input): + return self.conv(input) + + +class DownConvBlock(TemporallySharedBlock): + def __init__( + self, + d_in, + d_out, + k, s, p, + pad_value=None, + norm="batch", + padding_mode="reflect", + ): + super(DownConvBlock, self).__init__(pad_value=pad_value) + self.down = ConvLayer( + nkernels=[d_in, d_in], + norm=norm, + k=k, s=s, p=p, + padding_mode=padding_mode, + ) + self.conv1 = ConvLayer( + nkernels=[d_in, d_out], + norm=norm, + padding_mode=padding_mode, + ) + self.conv2 = ConvLayer( + nkernels=[d_out, d_out], + norm=norm, + padding_mode=padding_mode, + last_relu=False # note: removing last ReLU in DownConvBlock because it adds onto residual connection + ) + + def forward(self, input): + out = self.down(input) + out = self.conv1(out) + out = out + self.conv2(out) + return out + + +def get_norm_layer(out_channels, num_feats, n_groups=4, layer_type='BatchNorm'): + if layer_type == 'batch': + return nn.BatchNorm2d(out_channels) + elif layer_type == 'instance': + return nn.InstanceNorm2d(out_channels) + elif layer_type == 'group': + return nn.GroupNorm(num_channels=num_feats, num_groups=n_groups) + +class UpConvBlock(nn.Module): + def __init__(self, d_in, d_out, k, s, p, norm_skip="batch", norm_up ="batch", norm="batch", n_groups=4, d_skip=None, padding_mode="reflect"): + super(UpConvBlock, self).__init__() + d = d_out if d_skip is None else d_skip + + # apply another CONV and norm to the skipped paired map + """" + self.skip_conv = nn.Sequential( + nn.Conv2d(in_channels=d, out_channels=d, kernel_size=1), + nn.BatchNorm2d(d), + nn.ReLU(), + ) + """ + if norm_skip in ['group', 'batch', 'instance']: + self.skip_conv = nn.Sequential( + nn.Conv2d(in_channels=d, out_channels=d, kernel_size=1), + get_norm_layer(d, d, n_groups, norm_skip), #nn.BatchNorm2d(d), + nn.ReLU()) + else: + self.skip_conv = nn.Sequential( + nn.Conv2d(in_channels=d, out_channels=d, kernel_size=1), + nn.ReLU()) + + # transposed CONV layer to perform upsampling + """ + self.up = nn.Sequential( + nn.ConvTranspose2d( + in_channels=d_in, out_channels=d_out, kernel_size=k, stride=s, padding=p + ), + nn.BatchNorm2d(d_out), + nn.ReLU(), + ) + """ + if norm_up in ['group', 'batch', 'instance']: + self.up = nn.Sequential( + nn.ConvTranspose2d(in_channels=d_in, out_channels=d_out, kernel_size=k, stride=s, padding=p), + get_norm_layer(d_out, d_out, n_groups, norm_up), #nn.BatchNorm2d(d_out), + nn.ReLU()) + else: + self.up = nn.Sequential( + nn.ConvTranspose2d(in_channels=d_in, out_channels=d_out, kernel_size=k, stride=s, padding=p), + nn.ReLU()) + + self.conv1 = ConvLayer( + nkernels=[d_out + d, d_out], norm=norm, padding_mode=padding_mode, # removing downsampling relu in UpConvBlock because of MobileNet2 + ) + self.conv2 = ConvLayer( + nkernels=[d_out, d_out], norm=norm, padding_mode=padding_mode, last_relu=False # removing last relu in UpConvBlock because it adds onto residual connection + ) + + def forward(self, input, skip): + out = self.up(input) # transposed CONV on previous layer + # apply another CONV and norm to the skipped input --> paired encoder map + out = torch.cat([out, self.skip_conv(skip)], dim=1) # concat '' with paired encoder map + out = self.conv1(out) # CONV again + out = out + self.conv2(out) # conv with residual + return out + + +class Temporal_Aggregator(nn.Module): + def __init__(self, mode="mean"): + super(Temporal_Aggregator, self).__init__() + self.mode = mode + + def forward(self, x, pad_mask=None, attn_mask=None): + if pad_mask is not None and pad_mask.any(): + if self.mode == "att_group": + n_heads, b, t, h, w = attn_mask.shape + attn = attn_mask.view(n_heads * b, t, h, w) + + if x.shape[-2] > w: + attn = nn.Upsample( + size=x.shape[-2:], mode="bilinear", align_corners=False + )(attn) + else: + attn = nn.AvgPool2d(kernel_size=w // x.shape[-2])(attn) + + attn = attn.view(n_heads, b, t, *x.shape[-2:]) + attn = attn * (~pad_mask).float()[None, :, :, None, None] + + out = torch.stack(x.chunk(n_heads, dim=2)) # hxBxTxC/hxHxW + out = attn[:, :, :, None, :, :] * out + out = out.sum(dim=2) # sum on temporal dim -> hxBxC/hxHxW + out = torch.cat([group for group in out], dim=1) # -> BxCxHxW + return out + elif self.mode == "att_mean": + attn = attn_mask.mean(dim=0) # average over heads -> BxTxHxW + attn = nn.Upsample( + size=x.shape[-2:], mode="bilinear", align_corners=False + )(attn) + attn = attn * (~pad_mask).float()[:, :, None, None] + out = (x * attn[:, :, None, :, :]).sum(dim=1) + return out + elif self.mode == "mean": + out = x * (~pad_mask).float()[:, :, None, None, None] + out = out.sum(dim=1) / (~pad_mask).sum(dim=1)[:, None, None, None] + return out + else: + if self.mode == "att_group": + n_heads, b, t, h, w = attn_mask.shape + attn = attn_mask.view(n_heads * b, t, h, w) + if x.shape[-2] > w: + attn = nn.Upsample( + size=x.shape[-2:], mode="bilinear", align_corners=False + )(attn) + else: + attn = nn.AvgPool2d(kernel_size=w // x.shape[-2])(attn) + attn = attn.view(n_heads, b, t, *x.shape[-2:]) + out = torch.stack(x.chunk(n_heads, dim=2)) # hxBxTxC/hxHxW + out = attn[:, :, :, None, :, :] * out + out = out.sum(dim=2) # sum on temporal dim -> hxBxC/hxHxW + out = torch.cat([group for group in out], dim=1) # -> BxCxHxW + return out + elif self.mode == "att_mean": + attn = attn_mask.mean(dim=0) # average over heads -> BxTxHxW + attn = nn.Upsample( + size=x.shape[-2:], mode="bilinear", align_corners=False + )(attn) + out = (x * attn[:, :, None, :, :]).sum(dim=1) + return out + elif self.mode == "mean": + return x.mean(dim=1) + + +class RecUNet(nn.Module): + """Recurrent U-Net architecture. Similar to the U-TAE architecture but + the L-TAE is replaced by a recurrent network + and temporal averages are computed for the skip connections.""" + + def __init__( + self, + input_dim, + encoder_widths=[64, 64, 64, 128], + decoder_widths=[32, 32, 64, 128], + out_conv=[13], + str_conv_k=4, + str_conv_s=2, + str_conv_p=1, + temporal="lstm", + input_size=128, + encoder_norm="group", + hidden_dim=128, + encoder=False, + padding_mode="reflect", + pad_value=0, + ): + super(RecUNet, self).__init__() + self.n_stages = len(encoder_widths) + self.temporal = temporal + self.encoder_widths = encoder_widths + self.decoder_widths = decoder_widths + self.enc_dim = ( + decoder_widths[0] if decoder_widths is not None else encoder_widths[0] + ) + self.stack_dim = ( + sum(decoder_widths) if decoder_widths is not None else sum(encoder_widths) + ) + self.pad_value = pad_value + + self.encoder = encoder + if encoder: + self.return_maps = True + else: + self.return_maps = False + + if decoder_widths is not None: + assert len(encoder_widths) == len(decoder_widths) + assert encoder_widths[-1] == decoder_widths[-1] + else: + decoder_widths = encoder_widths + + self.in_conv = ConvBlock( + nkernels=[input_dim] + [encoder_widths[0], encoder_widths[0]], + pad_value=pad_value, + norm=encoder_norm, + ) + + self.down_blocks = nn.ModuleList( + DownConvBlock( + d_in=encoder_widths[i], + d_out=encoder_widths[i + 1], + k=str_conv_k, + s=str_conv_s, + p=str_conv_p, + pad_value=pad_value, + norm=encoder_norm, + padding_mode=padding_mode, + ) + for i in range(self.n_stages - 1) + ) + self.up_blocks = nn.ModuleList( + UpConvBlock( + d_in=decoder_widths[i], + d_out=decoder_widths[i - 1], + d_skip=encoder_widths[i - 1], + k=str_conv_k, + s=str_conv_s, + p=str_conv_p, + norm=encoder_norm, + padding_mode=padding_mode, + ) + for i in range(self.n_stages - 1, 0, -1) + ) + self.temporal_aggregator = Temporal_Aggregator(mode="mean") + + if temporal == "mean": + self.temporal_encoder = Temporal_Aggregator(mode="mean") + elif temporal == "lstm": + size = int(input_size / str_conv_s ** (self.n_stages - 1)) + self.temporal_encoder = ConvLSTM( + input_dim=encoder_widths[-1], + input_size=(size, size), + hidden_dim=hidden_dim, + kernel_size=(3, 3), + ) + self.out_convlstm = nn.Conv2d( + in_channels=hidden_dim, + out_channels=encoder_widths[-1], + kernel_size=3, + padding=1, + ) + elif temporal == "blstm": + size = int(input_size / str_conv_s ** (self.n_stages - 1)) + self.temporal_encoder = BConvLSTM( + input_dim=encoder_widths[-1], + input_size=(size, size), + hidden_dim=hidden_dim, + kernel_size=(3, 3), + ) + self.out_convlstm = nn.Conv2d( + in_channels=2 * hidden_dim, + out_channels=encoder_widths[-1], + kernel_size=3, + padding=1, + ) + elif temporal == "mono": + self.temporal_encoder = None + self.out_conv = ConvBlock(nkernels=[decoder_widths[0]] + out_conv, k=1, s=1, p=0, padding_mode=padding_mode) + + def forward(self, input, batch_positions=None): + pad_mask = ( + (input == self.pad_value).all(dim=-1).all(dim=-1).all(dim=-1) + ) # BxT pad mask + + out = self.in_conv.smart_forward(input) + + feature_maps = [out] + # ENCODER + for i in range(self.n_stages - 1): + out = self.down_blocks[i].smart_forward(feature_maps[-1]) + feature_maps.append(out) + + # Temporal encoder + if self.temporal == "mean": + out = self.temporal_encoder(feature_maps[-1], pad_mask=pad_mask) + elif self.temporal == "lstm": + _, out = self.temporal_encoder(feature_maps[-1], pad_mask=pad_mask) + out = out[0][1] # take last cell state as embedding + out = self.out_convlstm(out) + elif self.temporal == "blstm": + out = self.temporal_encoder(feature_maps[-1], pad_mask=pad_mask) + out = self.out_convlstm(out) + elif self.temporal == "mono": + out = feature_maps[-1] + + if self.return_maps: + maps = [out] + for i in range(self.n_stages - 1): + if self.temporal != "mono": + skip = self.temporal_aggregator( + feature_maps[-(i + 2)], pad_mask=pad_mask + ) + else: + skip = feature_maps[-(i + 2)] + out = self.up_blocks[i](out, skip) + if self.return_maps: + maps.append(out) + + if self.encoder: + return out, maps + else: + out = self.out_conv(out) + if self.return_maps: + return out, maps + else: + return out \ No newline at end of file diff --git a/UnCRtainTS/model/src/learning/__pycache__/metrics.cpython-311.pyc b/UnCRtainTS/model/src/learning/__pycache__/metrics.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a677c8832a0a3b7de00b8c3ba343e1b462557b07 Binary files /dev/null and b/UnCRtainTS/model/src/learning/__pycache__/metrics.cpython-311.pyc differ diff --git a/UnCRtainTS/model/src/learning/__pycache__/weight_init.cpython-311.pyc b/UnCRtainTS/model/src/learning/__pycache__/weight_init.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7fbed3e0013183c52557d691332b89a20a199df9 Binary files /dev/null and b/UnCRtainTS/model/src/learning/__pycache__/weight_init.cpython-311.pyc differ diff --git a/UnCRtainTS/model/src/learning/metrics.py b/UnCRtainTS/model/src/learning/metrics.py new file mode 100644 index 0000000000000000000000000000000000000000..b55665bde5ed96735b7a0c3dea66b63bf9b639f0 --- /dev/null +++ b/UnCRtainTS/model/src/learning/metrics.py @@ -0,0 +1,101 @@ +import os +import sys +import torch +import numpy as np + +sys.path.append(os.path.dirname(os.getcwd())) +sys.path.append(os.path.dirname(os.path.dirname(os.getcwd()))) +from util import pytorch_ssim + +class Metric(object): + """Base class for all metrics. + From: https://github.com/pytorch/tnt/blob/master/torchnet/meter/meter.py + """ + + def reset(self): pass + def add(self): pass + def value(self): pass + + +def img_metrics(target, pred, var=None, pixelwise=True): + rmse = torch.sqrt(torch.mean(torch.square(target - pred))) + psnr = 20 * torch.log10(1 / rmse) + mae = torch.mean(torch.abs(target - pred)) + + # spectral angle mapper + mat = target * pred + mat = torch.sum(mat, 1) + mat = torch.div(mat, torch.sqrt(torch.sum(target * target, 1))) + mat = torch.div(mat, torch.sqrt(torch.sum(pred * pred, 1))) + sam = torch.mean(torch.acos(torch.clamp(mat, -1, 1))*torch.tensor(180)/torch.pi) + + ssim = pytorch_ssim.ssim(target, pred) + + metric_dict = {'RMSE': rmse.cpu().numpy().item(), + 'MAE': mae.cpu().numpy().item(), + 'PSNR': psnr.cpu().numpy().item(), + 'SAM': sam.cpu().numpy().item(), + 'SSIM': ssim.cpu().numpy().item()} + + # evaluate the (optional) variance maps + if var is not None: + error = target - pred + # average across the spectral dimensions + se = torch.square(error) + ae = torch.abs(error) + + # collect sample-wise error, AE, SE and uncertainties + + # define a sample as 1 image and provide image-wise statistics + errvar_samplewise = {'error': error.nanmean().cpu().numpy().item(), + 'mean ae': ae.nanmean().cpu().numpy().item(), + 'mean se': se.nanmean().cpu().numpy().item(), + 'mean var': var.nanmean().cpu().numpy().item()} + if pixelwise: + # define a sample as 1 multivariate pixel and provide image-wise statistics + errvar_samplewise = {**errvar_samplewise, **{'pixelwise error': error.nanmean(0).nanmean(0).flatten().cpu().numpy(), + 'pixelwise ae': ae.nanmean(0).nanmean(0).flatten().cpu().numpy(), + 'pixelwise se': se.nanmean(0).nanmean(0).flatten().cpu().numpy(), + 'pixelwise var': var.nanmean(0).nanmean(0).flatten().cpu().numpy()}} + + metric_dict = {**metric_dict, **errvar_samplewise} + + return metric_dict + +class avg_img_metrics(Metric): + def __init__(self): + super().__init__() + self.n_samples = 0 + self.metrics = ['RMSE', 'MAE', 'PSNR','SAM','SSIM'] + self.metrics += ['error', 'mean se', 'mean ae', 'mean var'] + + self.running_img_metrics = {} + self.running_nonan_count = {} + self.reset() + + def reset(self): + for metric in self.metrics: + self.running_nonan_count[metric] = 0 + self.running_img_metrics[metric] = np.nan + + def add(self, metrics_dict): + for key, val in metrics_dict.items(): + # skip variables not registered + if key not in self.metrics: continue + # filter variables not translated to numpy yet + if torch.is_tensor(val): continue + if isinstance(val, tuple): val=val[0] + + # only keep a running mean of non-nan values + if np.isnan(val): continue + + if not self.running_nonan_count[key]: + self.running_nonan_count[key] = 1 + self.running_img_metrics[key] = val + else: + self.running_nonan_count[key]+= 1 + self.running_img_metrics[key] = (self.running_nonan_count[key]-1)/self.running_nonan_count[key] * self.running_img_metrics[key] \ + + 1/self.running_nonan_count[key] * val + + def value(self): + return self.running_img_metrics \ No newline at end of file diff --git a/UnCRtainTS/model/src/learning/weight_init.py b/UnCRtainTS/model/src/learning/weight_init.py new file mode 100644 index 0000000000000000000000000000000000000000..019d7e601f0188a8fdbeb1e908cf6c07ed22c4dd --- /dev/null +++ b/UnCRtainTS/model/src/learning/weight_init.py @@ -0,0 +1,75 @@ +import torch.nn as nn +import torch.nn.init as init + +def weight_init(m, spread=1.0): + ''' + Initializes a model's parameters. + Credits to: https://gist.github.com/jeasinema + + Usage: + model = Model() + model.apply(weight_init) + ''' + if isinstance(m, nn.Conv1d): + init.normal_(m.weight.data, mean=0, std=spread) + if m.bias is not None: + init.normal_(m.bias.data, mean=0, std=spread) + elif isinstance(m, nn.Conv2d): + init.xavier_normal_(m.weight.data, gain=spread) + if m.bias is not None: + init.normal_(m.bias.data, mean=0, std=spread) + elif isinstance(m, nn.Conv3d): + init.xavier_normal_(m.weight.data, gain=spread) + if m.bias is not None: + init.normal_(m.bias.data, mean=0, std=spread) + elif isinstance(m, nn.ConvTranspose1d): + init.normal_(m.weight.data, mean=0, std=spread) + if m.bias is not None: + init.normal_(m.bias.data, mean=0, std=spread) + elif isinstance(m, nn.ConvTranspose2d): + init.xavier_normal_(m.weight.data, gain=spread) + if m.bias is not None: + init.normal_(m.bias.data, mean=0, std=spread) + elif isinstance(m, nn.ConvTranspose3d): + init.xavier_normal_(m.weight.data, gain=spread) + if m.bias is not None: + init.normal_(m.bias.data, mean=0, std=spread) + elif isinstance(m, nn.BatchNorm1d): + init.normal_(m.weight.data, mean=0, std=spread) + init.constant_(m.bias.data, 0) + elif isinstance(m, nn.BatchNorm2d): + init.normal_(m.weight.data, mean=0, std=spread) + init.constant_(m.bias.data, 0) + elif isinstance(m, nn.BatchNorm3d): + init.normal_(m.weight.data, mean=0, std=spread) + init.constant_(m.bias.data, 0) + elif isinstance(m, nn.Linear): + init.xavier_normal_(m.weight.data, gain=spread) + try: + init.normal_(m.bias.data, mean=0, std=spread) + except AttributeError: + pass + elif isinstance(m, nn.LSTM): + for param in m.parameters(): + if len(param.shape) >= 2: + init.orthogonal_(param.data) + else: + init.normal_(param.data, mean=0, std=spread) + elif isinstance(m, nn.LSTMCell): + for param in m.parameters(): + if len(param.shape) >= 2: + init.orthogonal_(param.data) + else: + init.normal_(param.data, mean=0, std=spread) + elif isinstance(m, nn.GRU): + for param in m.parameters(): + if len(param.shape) >= 2: + init.orthogonal_(param.data) + else: + init.normal_(param.data, mean=0, std=spread) + elif isinstance(m, nn.GRUCell): + for param in m.parameters(): + if len(param.shape) >= 2: + init.orthogonal_(param.data) + else: + init.normal_(param.data, mean=0, std=spread) \ No newline at end of file diff --git a/UnCRtainTS/model/src/losses.py b/UnCRtainTS/model/src/losses.py new file mode 100644 index 0000000000000000000000000000000000000000..29e0f98483cdfff1a7334ebad496dab4396e5711 --- /dev/null +++ b/UnCRtainTS/model/src/losses.py @@ -0,0 +1,354 @@ +import math +import torch +Tensor = torch.Tensor +import torch.nn as nn +import torch.nn.modules.loss +from torch.nn.modules.loss import _Loss +from torch.overrides import has_torch_function_variadic, handle_torch_function + +from torch import vmap + +S2_BANDS = 13 + + +def get_loss(config): + if config.loss == "GNLL": + criterion1 = GaussianNLLLoss(reduction='mean', eps=1e-8, full=True) + criterion = lambda pred, targ, var: criterion1(pred, targ, var) + elif config.loss == "MGNLL": + criterion1 = MultiGaussianNLLLoss(reduction='mean', eps=1e-8, full=True, mode=config.covmode, chunk=config.chunk_size) + criterion = lambda pred, targ, var: criterion1(pred, targ, var) + elif config.loss=="l1": + criterion1 = nn.L1Loss() + criterion = lambda pred, targ: criterion1(pred, targ) + elif config.loss=="l2": + criterion1 = nn.MSELoss() + criterion = lambda pred, targ: criterion1(pred, targ) + else: raise NotImplementedError + + # wrap losses + loss_wrap = lambda *args: args + loss = loss_wrap(criterion) + return loss if not isinstance(loss, tuple) else loss[0] + + +def calc_loss(criterion, config, out, y, var=None): + + if config.loss in ['GNLL']: + loss, variance = criterion(out, y, var) + elif config.loss in ['MGNLL']: + loss, variance = criterion(out, y, var) + else: + loss, variance = criterion(out, y), None + return loss, variance + + +def gaussian_nll_loss( + input: Tensor, + target: Tensor, + var: Tensor, + full: bool = False, + eps: float = 1e-8, + reduction: str = "mean", +) -> Tensor: + r"""Gaussian negative log likelihood loss. + + based on :class:`~torch.nn.GaussianNLLLoss` for details. + + Args: + input: expectation of the Gaussian distribution. + target: sample from the Gaussian distribution. + var: tensor of positive variance(s), one for each of the expectations + in the input (heteroscedastic), or a single one (homoscedastic). + full (bool, optional): include the constant term in the loss calculation. Default: ``False``. + eps (float, optional): value added to var, for stability. Default: 1e-6. + reduction (string, optional): specifies the reduction to apply to the output: + ``'none'`` | ``'mean'`` | ``'sum'``. ``'none'``: no reduction will be applied, + ``'mean'``: the output is the average of all batch member losses, + ``'sum'``: the output is the sum of all batch member losses. + Default: ``'mean'``. + """ + if has_torch_function_variadic(input, target, var): + return handle_torch_function( + gaussian_nll_loss, + (input, target, var), + input, + target, + var, + full=full, + eps=eps, + reduction=reduction, + ) + + # Check var size + # If var.size == input.size, the case is heteroscedastic and no further checks are needed. + # Otherwise: + if var.size() != input.size(): + + # If var is one dimension short of input, but the sizes match otherwise, then this is a homoscedastic case. + # e.g. input.size = (10, 2, 3), var.size = (10, 2) + # -> unsqueeze var so that var.shape = (10, 2, 1) + # this is done so that broadcasting can happen in the loss calculation + if input.size()[:-1] == var.size(): + var = torch.unsqueeze(var, dim=-1) + + # This checks if the sizes match up to the final dimension, and the final dimension of var is of size 1. + # This is also a homoscedastic case. + # e.g. input.size = (10, 2, 3), var.size = (10, 2, 1) + elif input.size()[:-1] == var.size()[:-1] and var.size(-1) == 1: # Heteroscedastic case + pass + + # If none of the above pass, then the size of var is incorrect. + else: + raise ValueError("var is of incorrect size") + + # Check validity of reduction mode + if reduction != 'none' and reduction != 'mean' and reduction != 'sum': + raise ValueError(reduction + " is not valid") + + # Entries of var must be non-negative + if torch.any(var < 0): + raise ValueError("var has negative entry/entries") + + # Clamp for stability + var = var.clone() + with torch.no_grad(): + var.clamp_(min=eps) + + # Calculate the loss + loss = 0.5 * (torch.log(var) + (input - target)**2 / var) + if full: + loss += 0.5 * math.log(2 * math.pi) + + if reduction == 'mean': + return loss.mean(), var + elif reduction == 'sum': + return loss.sum(), var + else: + return loss, var + + +def multi_diag_gaussian_nll(pred, target, var): + # maps var from [B x 1 x C] to [B x 1 x C x C] + pred, target, var = pred.squeeze(dim=1), target.squeeze(dim=1), var.squeeze(dim=1) + + k = pred.shape[-1] + prec = torch.diag_embed(1/var, offset=0, dim1=-2, dim2=-1) + # the log-determinant of a diagonal matrix is simply the trace of the log of the diagonal matrix + logdetv = var.log().sum() # this may be more numerically stable a general calculation + err = (pred - target).unsqueeze(dim=1) + # for the Mahalanobis distance xTCx to be defined and >= 0, the precision matrix must be positive definite + xTCx = torch.bmm(torch.bmm(err, prec), err.permute(0, 2, 1)).squeeze().nan_to_num().clamp(min=1e-9) # note: equals torch.bmm(torch.bmm(-err, prec), -err) + # define the NLL loss + loss = -(-k/2 * torch.log(2*torch.tensor(torch.pi)) - 1/2 * logdetv - 1/2 * xTCx) + + return loss, torch.diag_embed(var, offset=0, dim1=-2, dim2=-1).cpu() + + + +def multi_gaussian_nll_loss( + input: Tensor, + target: Tensor, + var: Tensor, + full: bool = False, + eps: float = 1e-8, + reduction: str = "mean", + mode: str = "diag", + chunk = None +) -> Tensor: + r"""Multivariate Gaussian negative log likelihood loss. + + based on :class:`~torch.nn.GaussianNLLLoss` for details. + + Args: + input: expectation of the Gaussian distribution. + target: sample from the Gaussian distribution. + var: tensor of positive variance(s), one for each of the expectations + in the input (heteroscedastic), or a single one (homoscedastic). + full (bool, optional): include the constant term in the loss calculation. Default: ``False``. + eps (float, optional): value added to var, for stability. Default: 1e-6. + reduction (string, optional): specifies the reduction to apply to the output: + ``'none'`` | ``'mean'`` | ``'sum'``. ``'none'``: no reduction will be applied, + ``'mean'``: the output is the average of all batch member losses, + ``'sum'``: the output is the sum of all batch member losses. + Default: ``'mean'``. + """ + if has_torch_function_variadic(input, target, var): + return handle_torch_function( + multi_gaussian_nll_loss, + (input, target, var), + input, + target, + var, + full=full, + eps=eps, + reduction=reduction, + mode=mode, + chunk=None + ) + + if mode=='iso': + # duplicate the scalar variance across all spectral dimensions + var = var.expand(-1,-1,S2_BANDS,-1,-1) + + # Check validity of reduction mode + if reduction != 'none' and reduction != 'mean' and reduction != 'sum': + raise ValueError(reduction + " is not valid") + + # Entries of var must be non-negative + if torch.any(var < 0): + raise ValueError("var has negative entry/entries") + + # Clamp for stability + var = var.clone() + with torch.no_grad(): + var[:,:,:S2_BANDS].clamp_(min=eps) + + if mode in ['iso', 'diag']: + mapdims = (-1,-1,-1) + loss, variance = vmap(vmap(multi_diag_gaussian_nll, in_dims=mapdims, chunk_size=chunk), in_dims=mapdims, chunk_size=chunk)(input, target, var) + + variance = variance.moveaxis(1,-1).moveaxis(0,-1).unsqueeze(1) + + if reduction == 'mean': + return loss.mean(), variance + elif reduction == 'sum': + return loss.sum(), variance + else: + return loss, variance + + + +class GaussianNLLLoss(_Loss): + r"""Gaussian negative log likelihood loss. + + The targets are treated as samples from Gaussian distributions with + expectations and variances predicted by the neural network. For a + ``target`` tensor modelled as having Gaussian distribution with a tensor + of expectations ``input`` and a tensor of positive variances ``var`` the loss is: + + .. math:: + \text{loss} = \frac{1}{2}\left(\log\left(\text{max}\left(\text{var}, + \ \text{eps}\right)\right) + \frac{\left(\text{input} - \text{target}\right)^2} + {\text{max}\left(\text{var}, \ \text{eps}\right)}\right) + \text{const.} + + where :attr:`eps` is used for stability. By default, the constant term of + the loss function is omitted unless :attr:`full` is ``True``. If ``var`` is not the same + size as ``input`` (due to a homoscedastic assumption), it must either have a final dimension + of 1 or have one fewer dimension (with all other sizes being the same) for correct broadcasting. + + Args: + full (bool, optional): include the constant term in the loss + calculation. Default: ``False``. + eps (float, optional): value used to clamp ``var`` (see note below), for + stability. Default: 1e-6. + reduction (string, optional): specifies the reduction to apply to the + output:``'none'`` | ``'mean'`` | ``'sum'``. ``'none'``: no reduction + will be applied, ``'mean'``: the output is the average of all batch + member losses, ``'sum'``: the output is the sum of all batch member + losses. Default: ``'mean'``. + + Shape: + - Input: :math:`(N, *)` or :math:`(*)` where :math:`*` means any number of additional + dimensions + - Target: :math:`(N, *)` or :math:`(*)`, same shape as the input, or same shape as the input + but with one dimension equal to 1 (to allow for broadcasting) + - Var: :math:`(N, *)` or :math:`(*)`, same shape as the input, or same shape as the input but + with one dimension equal to 1, or same shape as the input but with one fewer + dimension (to allow for broadcasting) + - Output: scalar if :attr:`reduction` is ``'mean'`` (default) or + ``'sum'``. If :attr:`reduction` is ``'none'``, then :math:`(N, *)`, same + shape as the input + + Note: + The clamping of ``var`` is ignored with respect to autograd, and so the + gradients are unaffected by it. + + Reference: + Nix, D. A. and Weigend, A. S., "Estimating the mean and variance of the + target probability distribution", Proceedings of 1994 IEEE International + Conference on Neural Networks (ICNN'94), Orlando, FL, USA, 1994, pp. 55-60 + vol.1, doi: 10.1109/ICNN.1994.374138. + """ + __constants__ = ['full', 'eps', 'reduction'] + full: bool + eps: float + + def __init__(self, *, full: bool = False, eps: float = 1e-8, reduction: str = 'mean') -> None: + super(GaussianNLLLoss, self).__init__(None, None, reduction) + self.full = full + self.eps = eps + + def forward(self, input: Tensor, target: Tensor, var: Tensor) -> Tensor: + return gaussian_nll_loss(input, target, var, full=self.full, eps=self.eps, reduction=self.reduction) + + + + +class MultiGaussianNLLLoss(_Loss): + r"""Multivariate Gaussian negative log likelihood loss. + + The targets are treated as samples from Gaussian distributions with + expectations and variances predicted by the neural network. For a + ``target`` tensor modelled as having Gaussian distribution with a tensor + of expectations ``input`` and a tensor of positive variances ``var`` the loss is: + + .. math:: + \text{loss} = \frac{1}{2}\left(\log\left(\text{max}\left(\text{var}, + \ \text{eps}\right)\right) + \frac{\left(\text{input} - \text{target}\right)^2} + {\text{max}\left(\text{var}, \ \text{eps}\right)}\right) + \text{const.} + + where :attr:`eps` is used for stability. By default, the constant term of + the loss function is omitted unless :attr:`full` is ``True``. If ``var`` is not the same + size as ``input`` (due to a homoscedastic assumption), it must either have a final dimension + of 1 or have one fewer dimension (with all other sizes being the same) for correct broadcasting. + + Args: + full (bool, optional): include the constant term in the loss + calculation. Default: ``False``. + eps (float, optional): value used to clamp ``var`` (see note below), for + stability. Default: 1e-6. + reduction (string, optional): specifies the reduction to apply to the + output:``'none'`` | ``'mean'`` | ``'sum'``. ``'none'``: no reduction + will be applied, ``'mean'``: the output is the average of all batch + member losses, ``'sum'``: the output is the sum of all batch member + losses. Default: ``'mean'``. + + Shape: + - Input: :math:`(N, *)` or :math:`(*)` where :math:`*` means any number of additional + dimensions + - Target: :math:`(N, *)` or :math:`(*)`, same shape as the input, or same shape as the input + but with one dimension equal to 1 (to allow for broadcasting) + - Var: :math:`(N, *)` or :math:`(*)`, same shape as the input, or same shape as the input but + with one dimension equal to 1, or same shape as the input but with one fewer + dimension (to allow for broadcasting) + - Latent: :math:`(N, *)` or :math:`(*)`, same shape as the input, or same shape as the input but + with one dimension equal to 1, or same shape as the input but with one fewer + dimension (to allow for broadcasting) + - Output: scalar if :attr:`reduction` is ``'mean'`` (default) or + ``'sum'``. If :attr:`reduction` is ``'none'``, then :math:`(N, *)`, same + shape as the input + + Note: + The clamping of ``var`` is ignored with respect to autograd, and so the + gradients are unaffected by it. + + Reference: + Nix, D. A. and Weigend, A. S., "Estimating the mean and variance of the + target probability distribution", Proceedings of 1994 IEEE International + Conference on Neural Networks (ICNN'94), Orlando, FL, USA, 1994, pp. 55-60 + vol.1, doi: 10.1109/ICNN.1994.374138. + """ + __constants__ = ['full', 'eps', 'reduction'] + full: bool + eps: float + + def __init__(self, *, full: bool = False, eps: float = 1e-8, reduction: str = 'mean', mode: str = 'diag', chunk: None) -> None: + super(MultiGaussianNLLLoss, self).__init__(None, None, reduction) + self.full = full + self.eps = eps + self.mode = mode + self.chunk = chunk + + def forward(self, input: Tensor, target: Tensor, var: Tensor) -> Tensor: + return multi_gaussian_nll_loss(input, target, var, full=self.full, eps=self.eps, reduction=self.reduction, mode=self.mode, chunk=self.chunk) diff --git a/UnCRtainTS/model/src/model_utils.py b/UnCRtainTS/model/src/model_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..3fcfddbaca9009824754f26de7e65cf1c7070588 --- /dev/null +++ b/UnCRtainTS/model/src/model_utils.py @@ -0,0 +1,232 @@ +import os +import torch + +sub_dir = os.path.join(os.getcwd(), 'model') +if os.path.isdir(sub_dir): os.chdir(sub_dir) +from src.backbones import base_model, utae, uncrtaints + +S1_BANDS = 2 +S2_BANDS = 13 + +def get_base_model(config): + model = base_model.BaseModel(config) + return model + +# for running image reconstruction +def get_generator(config): + if "unet" in config.model: + model = utae.UNet( + input_dim=S1_BANDS*config.use_sar+S2_BANDS, + encoder_widths=config.encoder_widths, + decoder_widths=config.decoder_widths, + out_conv=config.out_conv, + out_nonlin_mean=config.mean_nonLinearity, + out_nonlin_var=config.var_nonLinearity, + str_conv_k=4, + str_conv_s=2, + str_conv_p=1, + encoder_norm=config.encoder_norm, + norm_skip='batch', + norm_up='batch', + decoder_norm=config.decoder_norm, + encoder=False, + return_maps=False, + pad_value=config.pad_value, + padding_mode=config.padding_mode, + ) + elif "utae" in config.model: + if config.pretrain: + # on monotemporal data, just use a simple U-Net + model = utae.UNet( + input_dim=S1_BANDS*config.use_sar+S2_BANDS, + encoder_widths=config.encoder_widths, + decoder_widths=config.decoder_widths, + out_conv=config.out_conv, + out_nonlin_mean=config.mean_nonLinearity, + out_nonlin_var=config.var_nonLinearity, + str_conv_k=4, + str_conv_s=2, + str_conv_p=1, + encoder_norm=config.encoder_norm, + norm_skip='batch', + norm_up='batch', + decoder_norm=config.decoder_norm, + encoder=False, + return_maps=False, + pad_value=config.pad_value, + padding_mode=config.padding_mode, + ) + else: + model = utae.UTAE( + input_dim=S1_BANDS*config.use_sar+S2_BANDS, + encoder_widths=config.encoder_widths, + decoder_widths=config.decoder_widths, + out_conv=config.out_conv, + out_nonlin_mean=config.mean_nonLinearity, + out_nonlin_var=config.var_nonLinearity, + str_conv_k=4, + str_conv_s=2, + str_conv_p=1, + agg_mode=config.agg_mode, + encoder_norm=config.encoder_norm, + norm_skip='batch', + norm_up='batch', + decoder_norm=config.decoder_norm, + n_head=config.n_head, + d_model=config.d_model, + d_k=config.d_k, + encoder=False, + return_maps=False, + pad_value=config.pad_value, + padding_mode=config.padding_mode, + positional_encoding=config.positional_encoding, + scale_by=config.scale_by + ) + elif 'uncrtaints' == config.model: + model = uncrtaints.UNCRTAINTS( + input_dim=S1_BANDS*config.use_sar+S2_BANDS, + encoder_widths=config.encoder_widths, + decoder_widths=config.decoder_widths, + out_conv=config.out_conv, + out_nonlin_mean=config.mean_nonLinearity, + out_nonlin_var=config.var_nonLinearity, + agg_mode=config.agg_mode, + encoder_norm=config.encoder_norm, + decoder_norm=config.decoder_norm, + n_head=config.n_head, + d_model=config.d_model, + d_k=config.d_k, + pad_value=config.pad_value, + padding_mode=config.padding_mode, + positional_encoding=config.positional_encoding, + covmode=config.covmode, + scale_by=config.scale_by, + separate_out=config.separate_out, + use_v=config.use_v, + block_type=config.block_type, + is_mono=config.pretrain + ) + else: raise NotImplementedError + return model + + +def get_model(config): + return get_base_model(config) + + +def save_model(config, epoch, model, name): + state_dict = {"epoch": epoch, + "state_dict": model.state_dict(), + "state_dict_G": model.netG.state_dict(), + "optimizer_G": model.optimizer_G.state_dict(), + "scheduler_G": model.scheduler_G.state_dict()} + torch.save(state_dict, + os.path.join(config.res_dir, config.experiment_name, f"{name}.pth.tar"), + ) + + +def load_model(config, model, train_out_layer=True, load_out_partly=True): + # load pre-trained checkpoints, but only of matching weigths + + pretrained_dict = torch.load(config.trained_checkp, map_location=config.device)["state_dict_G"] + model_dict = model.netG.state_dict() + + not_str = "" if pretrained_dict.keys() == model_dict.keys() else "not " + print(f'The new and the (pre-)trained model architectures are {not_str}identical.\n') + + try:# try loading checkpoint strictly, all weights must match + # (this is satisfied e.g. when resuming training) + + if train_out_layer: raise NotImplementedError # move to 'except' case + model.netG.load_state_dict(pretrained_dict, strict=True) + freeze_layers(model.netG, grad=True) # set all weights to trainable, no need to freeze + model.frozen, freeze_these = False, [] # ... as all weights match appropriately + except: # if some weights don't match (e.g. when loading from pre-trained U-Net), then only load the compatible subset ... + # ... freeze compatible weights and make the incompatibel weights trainable + + # load output layer partly, e.g. when pretrained net has 3 output channels but novel model has 13 + if load_out_partly: + # overwrite output layer even when dimensions mismatch (this overwrites kernels individually) + #""" # these lines were used for predicting the 13 mean bands when mean and var shared a single output layer + temp_weights, temp_biases = model_dict['out_conv.conv.conv.0.weight'], model_dict['out_conv.conv.conv.0.bias'] + temp_weights[:S2_BANDS,...] = pretrained_dict['out_conv.conv.conv.0.weight'][:S2_BANDS,...] + temp_biases[:S2_BANDS,...] = pretrained_dict['out_conv.conv.conv.0.bias'][:S2_BANDS,...] + pretrained_dict['out_conv.conv.conv.0.weight'] = temp_weights[:S2_BANDS,...] + pretrained_dict['out_conv.conv.conv.0.bias'] = temp_biases[:S2_BANDS,...] + """ + if 'out_conv.conv.conv.0.weight' in pretrained_dict: # if predicting from a model with a single output layer for both mean and var + pretrained_dict['out_conv_mean.conv.conv.0.weight'] = pretrained_dict['out_conv.conv.conv.0.weight'][:S2_BANDS,...] + pretrained_dict['out_conv_mean.conv.conv.0.bias'] = pretrained_dict['out_conv.conv.conv.0.bias'][:S2_BANDS,...] + if 'out_conv_var.conv.conv.0.weight' in model_dict: + pretrained_dict['out_conv_var.conv.conv.0.weight'] = model_dict['out_conv_var.conv.conv.0.weight'] + pretrained_dict['out_conv_var.conv.conv.0.bias'] = model_dict['out_conv_var.conv.conv.0.bias'] + """ + + # check for size mismatch and exclude layers whose dimensions mismatch (they won't be loaded) + pretrained_dict = {k:v for k,v in pretrained_dict.items() if k in model_dict and v.size() == model_dict[k].size()} + model_dict.update(pretrained_dict) + model.netG.load_state_dict(model_dict, strict=False) + + # freeze pretrained weights + model.frozen = True + freeze_layers(model.netG, grad=True) # set all weights to trainable, except final ... + if train_out_layer: + # freeze all but last layer + all_but_last = {k:v for k, v in pretrained_dict.items() if 'out_conv.conv.conv.0' not in k} + freeze_layers(model.netG, apply_to=all_but_last, grad=False) + freeze_these = list(all_but_last.keys()) + else: # freeze all pre-trained layers, without exceptions + freeze_layers(model.netG, apply_to=pretrained_dict, grad=False) + freeze_these = list(pretrained_dict.keys()) + train_these = [train_layer for train_layer in list(model_dict.keys()) if train_layer not in freeze_these] + print(f'\nFroze these layers: {freeze_these}') + print(f'\nTrain these layers: {train_these}') + + if config.resume_from: + resume_at = int(config.trained_checkp.split('.pth.tar')[0].split('_')[-1]) + print(f'\nResuming training at epoch {resume_at+1}/{config.epochs}, loading optimizers and schedulers') + # if continuing training, then also load states of previous runs' optimizers and schedulers + # ---else, we start optimizing from scratch but with the model parameters loaded above + optimizer_G_dict = torch.load(config.trained_checkp, map_location=config.device)["optimizer_G"] + model.optimizer_G.load_state_dict(optimizer_G_dict) + + scheduler_G_dict = torch.load(config.trained_checkp, map_location=config.device)["scheduler_G"] + model.scheduler_G.load_state_dict(scheduler_G_dict) + + # no return value, models are passed by reference + + +# function to load checkpoints of individual and ensemble models +# (this is used for training and testing scripts) +def load_checkpoint(config, checkp_dir, model, name): + print(checkp_dir) + chckp_path = os.path.join(checkp_dir, config.experiment_name, f"{name}.pth.tar") + print(f'Loading checkpoint {chckp_path}') + checkpoint = torch.load(chckp_path, map_location=config.device)["state_dict"] + + try: # try loading checkpoint strictly, all weights & their names must match + model.load_state_dict(checkpoint, strict=True) + except: + # rename keys + # in_block1 -> in_block0, out_block1 -> out_block0 + checkpoint_renamed = dict() + for key, val in checkpoint.items(): + if 'in_block' in key or 'out_block' in key: + strs = key.split('.') + strs[1] = strs[1][:-1] + str(int(strs[1][-1])-1) + strs[1] = '.'.join([strs[1][:-1], strs[1][-1]]) + key = '.'.join(strs) + checkpoint_renamed[key] = val + model.load_state_dict(checkpoint_renamed, strict=False) + +def freeze_layers(net, apply_to=None, grad=False): + if net is not None: + for k, v in net.named_parameters(): + # check if layer is supposed to be frozen + if hasattr(v, 'requires_grad') and v.dtype != torch.int64: + if apply_to is not None: + # flip + if k in apply_to.keys() and v.size() == apply_to[k].size(): + v.requires_grad_(grad) + else: # otherwise apply indiscriminately to all layers + v.requires_grad_(grad) \ No newline at end of file diff --git a/UnCRtainTS/model/src/utils.py b/UnCRtainTS/model/src/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..23052acded0d0534e604a3e197e9d26a1308d7f3 --- /dev/null +++ b/UnCRtainTS/model/src/utils.py @@ -0,0 +1,76 @@ +import re +import collections.abc + +import torch +from torch.nn import functional as F + +np_str_obj_array_pattern = re.compile(r"[SaUO]") + +# map arg string of written list to list +def str2list(config, list_args): + for k, v in vars(config).items(): + if k in list_args and v is not None and isinstance(v, str): + v = v.replace("[", "") + v = v.replace("]", "") + config.__setattr__(k, list(map(int, v.split(",")))) + return config + + + +def pad_tensor(x, l, pad_value=0): + padlen = l - x.shape[0] + pad = [0 for _ in range(2 * len(x.shape[1:]))] + [0, padlen] + return F.pad(x, pad=pad, value=pad_value) + + +def pad_collate(batch, pad_value=0): + # modified default_collate from the official pytorch repo + # https://github.com/pytorch/pytorch/blob/master/torch/utils/data/_utils/collate.py + elem = batch[0] + elem_type = type(elem) + if isinstance(elem, torch.Tensor): + out = None + if len(elem.shape) > 0: + sizes = [e.shape[0] for e in batch] + m = max(sizes) + if not all(s == m for s in sizes): + # pad tensors which have a temporal dimension + batch = [pad_tensor(e, m, pad_value=pad_value) for e in batch] + if torch.utils.data.get_worker_info() is not None: + # If we're in a background process, concatenate directly into a + # shared memory tensor to avoid an extra copy + numel = sum([x.numel() for x in batch]) + storage = elem.storage()._new_shared(numel) + out = elem.new(storage) + return torch.stack(batch, 0, out=out) + elif ( + elem_type.__module__ == "numpy" + and elem_type.__name__ != "str_" + and elem_type.__name__ != "string_" + ): + if elem_type.__name__ == "ndarray" or elem_type.__name__ == "memmap": + # array of string classes and object + if np_str_obj_array_pattern.search(elem.dtype.str) is not None: + raise TypeError("Format not managed : {}".format(elem.dtype)) + + return pad_collate([torch.as_tensor(b) for b in batch]) + elif elem.shape == (): # scalars + return torch.as_tensor(batch) + elif isinstance(elem, collections.abc.Mapping): + return {key: pad_collate([d[key] for d in batch]) for key in elem} + elif isinstance(elem, tuple) and hasattr(elem, "_fields"): # namedtuple + return elem_type(*(pad_collate(samples) for samples in zip(*batch))) + elif isinstance(elem, collections.abc.Sequence): + # check to make sure that the elements in batch have consistent size + it = iter(batch) + elem_size = len(next(it)) + if not all(len(elem) == elem_size for elem in it): + raise RuntimeError("each element in list of batch should be of equal size") + transposed = zip(*batch) + return [pad_collate(samples) for samples in transposed] + + raise TypeError("Format not managed : {}".format(elem_type)) + + +def get_ntrainparams(model): + return sum(p.numel() for p in model.parameters() if p.requires_grad) \ No newline at end of file diff --git a/UnCRtainTS/model/test_reconstruct.py b/UnCRtainTS/model/test_reconstruct.py new file mode 100644 index 0000000000000000000000000000000000000000..7e69bd4b5153ad97b6874fa631a8f68561d40797 --- /dev/null +++ b/UnCRtainTS/model/test_reconstruct.py @@ -0,0 +1,194 @@ +""" +Script for image reconstruction inference with pre-trained models +Author: Patrick Ebel (github/PatrickTUM), based on the scripts of + Vivien Sainte Fare Garnot (github/VSainteuf) +License: MIT +""" + +import os +import sys +import json +import pprint +import argparse +from parse_args import create_parser + +import torch + +dirname = os.path.dirname(os.path.abspath(__file__)) +sys.path.append(os.path.dirname(dirname)) + +from src import utils +from src.model_utils import get_model, load_checkpoint +from train_reconstruct import ( + iterate, + save_results, + prepare_output, + import_from_path, + seed_packages, +) +from data.dataLoader import SEN12MSCR, get_pairedS1 + +from torch.utils.tensorboard import SummaryWriter + +parser = create_parser(mode="test") +test_config = parser.parse_args() + +# grab the PID so we can look it up in the logged config for server-side process management +test_config.pid = os.getpid() + +# related to flag --use_custom: +# define custom target S2 patches (these will be mosaiced into a single sample), and fetch associated target S1 patches as well as input data +# (TODO: keeping this hard-coded until a more convenient way to pass it as an argument comes about ...) +targ_s2 = [ + f"ROIs1868/73/S2/14/s2_ROIs1868_73_ImgNo_14_2018-06-21_patch_{pdx}.tif" + for pdx in [171, 172, 173, 187, 188, 189, 203, 204, 205] +] + +# load previous config from training directories + +# if no custom path to config file is passed, try fetching config file at default location +conf_path = ( + os.path.join( + dirname, test_config.weight_folder, test_config.experiment_name, "conf.json" + ) + if not test_config.load_config + else test_config.load_config +) +if os.path.isfile(conf_path): + with open(conf_path) as file: + model_config = json.loads(file.read()) + t_args = argparse.Namespace() + # do not overwrite the following flags by their respective values in the config file + no_overwrite = [ + "pid", + "device", + "resume_at", + "trained_checkp", + "res_dir", + "weight_folder", + "root1", + "root2", + "root3", + "max_samples_count", + "batch_size", + "display_step", + "plot_every", + "export_every", + "input_t", + "region", + "min_cov", + "max_cov", + ] + conf_dict = { + key: val for key, val in model_config.items() if key not in no_overwrite + } + for key, val in vars(test_config).items(): + if key in no_overwrite: + conf_dict[key] = val + t_args.__dict__.update(conf_dict) + config = parser.parse_args(namespace=t_args) +else: + config = test_config # otherwise, keep passed flags without any overwriting +config = utils.str2list(config, ["encoder_widths", "decoder_widths", "out_conv"]) + +if config.pretrain: + config.batch_size = 32 + +experime_dir = os.path.join(config.res_dir, config.experiment_name) +if not os.path.exists(experime_dir): + os.makedirs(experime_dir) +with open(os.path.join(experime_dir, "conf.json"), "w") as file: + file.write(json.dumps(vars(config), indent=4)) + +# seed everything +seed_packages(config.rdm_seed) +if __name__ == "__main__": + pprint.pprint(config) + +# instantiate tensorboard logger +writer = SummaryWriter(os.path.join(config.res_dir, config.experiment_name)) + + +if config.use_custom: + print("Testing on custom data samples") + # define a dictionary for the custom sample, with customized ROI and time points + custom = [ + { + "input": { + "S1": [ + get_pairedS1(targ_s2, config.root1, mod="s1", time=tdx) + for tdx in range(0, 3) + ], + "S2": [ + get_pairedS1(targ_s2, config.root1, mod="s2", time=tdx) + for tdx in range(0, 3) + ], + }, + "target": { + "S1": [get_pairedS1(targ_s2, config.root1, mod="s1")], + "S2": [targ_s2], + }, + } + ] + + +def main(config): + device = torch.device(config.device) + prepare_output(config) + + model = get_model(config) + model = model.to(device) + config.N_params = utils.get_ntrainparams(model) + print(f"TOTAL TRAINABLE PARAMETERS: {config.N_params}\n") + # print(model) + + # get data loader + if config.pretrain: + dt_test = SEN12MSCR( + config.root3, + split="test", + region=config.region, + sample_type=config.sample_type, + ) + else: + pass + + dt_test = torch.utils.data.Subset( + dt_test, range(0, min(config.max_samples_count, len(dt_test))) + ) + test_loader = torch.utils.data.DataLoader( + dt_test, batch_size=config.batch_size, shuffle=False + ) + + # Load weights + ckpt_n = f"_epoch_{config.resume_at}" if config.resume_at > 0 else "" + load_checkpoint(config, config.weight_folder, model, f"model{ckpt_n}") + + # Inference + print("Testing . . .") + model.eval() + + _, test_img_metrics = iterate( + model, + data_loader=test_loader, + config=config, + writer=writer, + mode="test", + epoch=1, + device=device, + ) + print(f"\nTest image metrics: {test_img_metrics}") + + save_results( + test_img_metrics, + os.path.join(config.res_dir, config.experiment_name), + split="test", + ) + print( + f"\nLogged test metrics to path {os.path.join(config.res_dir, config.experiment_name)}" + ) + + +if __name__ == "__main__": + main(config) + exit() diff --git a/UnCRtainTS/model/train_reconstruct.py b/UnCRtainTS/model/train_reconstruct.py new file mode 100644 index 0000000000000000000000000000000000000000..d50fdd5ef5fde4e1a636d94beb72529984d1522e --- /dev/null +++ b/UnCRtainTS/model/train_reconstruct.py @@ -0,0 +1,1067 @@ +""" +Main script for image reconstruction experiments +Author: Patrick Ebel (github/PatrickTUM), based on the scripts of + Vivien Sainte Fare Garnot (github/VSainteuf) +License: MIT +""" + + +import os +import sys +import time +import json +import random +import pprint +import argparse +import numpy as np +from tqdm import tqdm +from matplotlib import pyplot as plt + +dirname = os.path.dirname(os.path.abspath(__file__)) +sys.path.append(os.path.dirname(dirname)) + +from parse_args import create_parser +from data.dataLoader import SEN12MSCR +from src.model_utils import ( + get_model, + save_model, + freeze_layers, + load_model, + load_checkpoint, +) +from src.learning.metrics import img_metrics, avg_img_metrics + +import torch +import torchnet as tnt +from torch.utils.tensorboard import SummaryWriter + +from src import utils, losses +from src.learning.weight_init import weight_init + +S2_BANDS = 13 +parser = create_parser(mode="train") +config = utils.str2list( + parser.parse_args(), list_args=["encoder_widths", "decoder_widths", "out_conv"] +) + +if config.model in ["unet", "utae"]: + assert len(config.encoder_widths) == len(config.decoder_widths) + config.loss = "l2" + if config.model == "unet": + # train U-Net from scratch + config.pretrain = True + config.trained_checkp = "" + +if config.pretrain: # pre-training is on a single time point + config.input_t = config.n_head = 1 + config.sample_type = "pretrain" + if config.model == "unet": + config.batch_size = 32 + config.positional_encoding = False + +if config.loss in ["GNLL", "MGNLL"]: + # for univariate losses, default to univariate mode (batched across channels) + if config.loss in ["GNLL"]: + config.covmode = "uni" + + if config.covmode == "iso": + config.out_conv[-1] += 1 + elif config.covmode in ["uni", "diag"]: + config.out_conv[-1] += S2_BANDS + config.var_nonLinearity = "softplus" + +# grab the PID so we can look it up in the logged config for server-side process management +config.pid = os.getpid() + +# import & re-load a previous configuration, e.g. to resume training +if config.resume_from: + load_conf = os.path.join(config.res_dir, config.experiment_name, "conf.json") + if config.experiment_name != config.trained_checkp.split("/")[-2]: + raise ValueError("Mismatch of loaded config file and checkpoints") + with open(load_conf, "rt") as f: + t_args = argparse.Namespace() + # do not overwrite the following flags by their respective values in the config file + no_overwrite = [ + "pid", + "num_workers", + "root1", + "root2", + "root3", + "resume_from", + "trained_checkp", + "epochs", + "encoder_widths", + "decoder_widths", + "lr", + ] + conf_dict = { + key: val for key, val in json.load(f).items() if key not in no_overwrite + } + for key, val in vars(config).items(): + if key in no_overwrite: + conf_dict[key] = val + t_args.__dict__.update(conf_dict) + config = parser.parse_args(namespace=t_args) +config = utils.str2list( + config, list_args=["encoder_widths", "decoder_widths", "out_conv"] +) + +# resume at a specified epoch and update optimizer accordingly +if config.resume_at >= 0: + config.lr = config.lr * config.gamma**config.resume_at + + +# fix all RNG seeds, +# throw the whole bunch at 'em +def seed_packages(seed): + random.seed(seed) + np.random.seed(seed) + torch.manual_seed(seed) + torch.cuda.manual_seed(seed) + torch.cuda.manual_seed_all(seed) + os.environ["PYTHONHASHSEED"] = str(seed) + # torch.use_deterministic_algorithms(True, warn_only=True) + torch.backends.cudnn.benchmark = False + + +def seed_worker(worker_id): + worker_seed = torch.initial_seed() % 2**32 + np.random.seed(worker_seed) + random.seed(worker_seed) + + +# seed everything +seed_packages(config.rdm_seed) +# seed generators for train & val/test dataloaders +f, g = torch.Generator(), torch.Generator() +f.manual_seed(config.rdm_seed + 0) # note: this may get re-seeded each epoch +g.manual_seed(config.rdm_seed) # keep this one fixed + +if __name__ == "__main__": + pprint.pprint(config) + +# instantiate tensorboard logger +writer = SummaryWriter( + os.path.join(os.path.dirname(config.res_dir), "logs", config.experiment_name) +) + + +def plot_img(imgs, mod, plot_dir, file_id=None): + if not os.path.exists(plot_dir): + os.makedirs(plot_dir) + try: + imgs = imgs.cpu().numpy() + for tdx, img in enumerate(imgs): # iterate over temporal dimension + time = "" if imgs.shape[0] == 1 else f"_t-{tdx}" + if mod in ["pred", "in", "target", "s2"]: + rgb = [3, 2, 1] if img.shape[0] == S2_BANDS else [5, 4, 3] + img, val_min, val_max = img[rgb, ...], 0, 1 + elif mod == "s1": + img, val_min, val_max = img[[0], ...], 0, 1 + elif mod == "mask": + img, val_min, val_max = img[[0], ...], 0, 1 + elif mod == "err": + img, val_min, val_max = img[[0], ...], 0, 0.01 + elif mod == "var": + img, val_min, val_max = img[[0], ...], 0, 0.000025 + else: + raise NotImplementedError + if file_id is not None: # export into file name + img = img.clip( + val_min, val_max + ) # note: this only removes outliers, vmin/vmax below do the global rescaling (else doing instance-wise min/max scaling) + plt.imsave( + os.path.join(plot_dir, f"img-{file_id}_{mod}{time}.png"), + np.moveaxis(img, 0, -1).squeeze(), + dpi=100, + cmap="gray", + vmin=val_min, + vmax=val_max, + ) + except: + if isinstance(imgs, plt.Figure): # the passed argument is a pre-rendered figure + plt.savefig(os.path.join(plot_dir, f"img-{file_id}_{mod}.png"), dpi=100) + else: + raise NotImplementedError + + +def export(arrs, mod, export_dir, file_id=None): + if not os.path.exists(export_dir): + os.makedirs(export_dir) + for tdx, arr in enumerate(arrs): # iterate over temporal dimension + num = "" if arrs.shape[0] == 1 else f"_t-{tdx}" + np.save(os.path.join(export_dir, f"img-{file_id}_{mod}{num}.npy"), arr.cpu()) + + +def prepare_data(batch, device, config): + if config.pretrain: + return prepare_data_mono(batch, device, config) + else: + return prepare_data_multi(batch, device, config) + + +def prepare_data_mono(batch, device, config): + x = batch["input"]["S2"].to(device).unsqueeze(1) + if config.use_sar: + x = torch.cat((batch["input"]["S1"].to(device).unsqueeze(1), x), dim=2) + m = batch["input"]["masks"].to(device).unsqueeze(1) + y = batch["target"]["S2"].to(device).unsqueeze(1) + return x, y, m + + +def prepare_data_multi(batch, device, config): + in_S2 = recursive_todevice(batch["input"]["S2"], device) + in_S2_td = recursive_todevice(batch["input"]["S2 TD"], device) + if config.batch_size > 1: + in_S2_td = torch.stack((in_S2_td)).T + in_m = torch.stack(recursive_todevice(batch["input"]["masks"], device)).swapaxes( + 0, 1 + ) + target_S2 = recursive_todevice(batch["target"]["S2"], device) + y = torch.cat(target_S2, dim=0).unsqueeze(1) + + if config.use_sar: + in_S1 = recursive_todevice(batch["input"]["S1"], device) + in_S1_td = recursive_todevice(batch["input"]["S1 TD"], device) + if config.batch_size > 1: + in_S1_td = torch.stack((in_S1_td)).T + x = torch.cat((torch.stack(in_S1, dim=1), torch.stack(in_S2, dim=1)), dim=2) + dates = ( + torch.stack((torch.tensor(in_S1_td), torch.tensor(in_S2_td))) + .float() + .mean(dim=0) + .to(device) + ) + else: + x = torch.stack(in_S2, dim=1) + dates = torch.tensor(in_S2_td).float().to(device) + + return x, y, in_m, dates + + +def log_aleatoric(writer, config, mode, step, var, name, img_meter=None): + # if var is of shape [B x 1 x C x C x H x W] then it's a covariance tensor + if len(var.shape) > 5: + covar = var + # get [B x 1 x C x H x W] variance tensor + var = var.diagonal(dim1=2, dim2=3).moveaxis(-1, 2) + + # compute spatial-average to visualize patch-wise covariance matrices + patch_covmat = covar.mean(dim=-1).mean(dim=-1).squeeze(dim=1) + for bdx, img in enumerate(patch_covmat): # iterate over [B x C x C] covmats + img = img.detach().numpy() + + max_abs = max(abs(img.min()), abs(img.max())) + scale_rel_left, scale_rel_right = -max_abs, +max_abs + fig = continuous_matshow(img, min=scale_rel_left, max=scale_rel_right) + writer.add_figure(f"Img/{mode}/patch covmat relative {bdx}", fig, step) + scale_center0_absolute = ( + 1 / 4 * 1**2 + ) # assuming covmat has been rescaled already, this is an upper bound + fig = continuous_matshow( + img, min=-scale_center0_absolute, max=scale_center0_absolute + ) + writer.add_figure(f"Img/{mode}/patch covmat absolute {bdx}", fig, step) + + # aleatoric uncertainty: comput during train, val and test + # note: the quantile statistics are computed solely over the variances (and would be much different if involving covariances, e.g. in the isotopic case) + avg_var = torch.mean( + var, dim=2, keepdim=True + ) # avg over bands, note: this only considers variances (else diag COV's avg would be tiny) + q50 = ( + avg_var[:, 0, ...].view(avg_var.shape[0], -1).median(dim=-1)[0].detach().clone() + ) + q75 = ( + avg_var[:, 0, ...] + .view(avg_var.shape[0], -1) + .quantile(0.75, dim=-1) + .detach() + .clone() + ) + q50, q75 = q50[0], q75[0] # take batch's first item as a summary + binning = 256 # see: https://pytorch.org/docs/stable/tensorboard.html#torch.utils.tensorboard.writer.SummaryWriter.add_histogram + + if config.loss in ["GNLL", "MGNLL"]: + writer.add_image( + f"Img/{mode}/{name}aleatoric [0,1]", + avg_var[0, 0, ...].clip(0, 1), + step, + dataformats="CHW", + ) # map image to [0, 1] + writer.add_image( + f"Img/{mode}/{name}aleatoric [0,q75]", + avg_var[0, 0, ...].clip(0.0, q75) / q75, + step, + dataformats="CHW", + ) # map image to [0, q75] + writer.add_histogram( + f"Hist/{mode}/{name}aleatoric", + avg_var[0, 0, ...].flatten().clip(0, 1), + step, + bins=binning, + max_bins=binning, + ) + else: + raise NotImplementedError + + writer.add_scalar(f"{mode}/{name}aleatoric median all", q50, step) + writer.add_scalar(f"{mode}/{name}aleatoric q75 all", q75, step) + if img_meter is not None: + writer.add_scalar(f"{mode}/{name}UCE SE", img_meter.value()["UCE SE"], step) + writer.add_scalar(f"{mode}/{name}AUCE SE", img_meter.value()["AUCE SE"], step) + + +def log_train(writer, config, model, step, x, out, y, in_m, name="", var=None): + # logged loss is before rescaling by learning rate + _, loss = model.criterion, model.loss_G.cpu() + if name != "": + name = f"model_{name}/" + + writer.add_scalar(f"train/{name}{config.loss}", loss, step) + writer.add_scalar(f"train/{name}total", loss, step) + # use add_images for batch-wise adding across temporal dimension + if config.use_sar: + writer.add_image( + f"Img/train/{name}in_s1", x[0, :, [0], ...], step, dataformats="NCHW" + ) + writer.add_image( + f"Img/train/{name}in_s2", x[0, :, [5, 4, 3], ...], step, dataformats="NCHW" + ) + else: + writer.add_image( + f"Img/train/{name}in_s2", x[0, :, [3, 2, 1], ...], step, dataformats="NCHW" + ) + writer.add_image( + f"Img/train/{name}out", out[0, 0, [3, 2, 1], ...], step, dataformats="CHW" + ) + writer.add_image( + f"Img/train/{name}y", y[0, 0, [3, 2, 1], ...], step, dataformats="CHW" + ) + writer.add_image( + f"Img/train/{name}m", in_m[0, :, None, ...], step, dataformats="NCHW" + ) + + # analyse cloud coverage + + # covered at ALL time points (AND) or covered at ANY time points (OR) + # and_m, or_m = torch.prod(in_m[0,:, ...], dim=0, keepdim=True), torch.sum(in_m[0,:, ...], dim=0, keepdim=True).clip(0,1) + and_m, or_m = torch.prod(in_m, dim=1, keepdim=True), torch.sum( + in_m, dim=1, keepdim=True + ).clip(0, 1) + writer.add_scalar(f"train/{name}OR m %", or_m.float().mean(), step) + writer.add_scalar(f"train/{name}AND m %", and_m.float().mean(), step) + writer.add_image(f"Img/train/{name}AND m", and_m, step, dataformats="NCHW") + writer.add_image(f"Img/train/{name}OR m", or_m, step, dataformats="NCHW") + + and_m_gray = in_m.float().mean(axis=1).cpu() + for bdx, img in enumerate(and_m_gray): + fig = discrete_matshow(img, n_colors=config.input_t) + writer.add_figure(f"Img/train/temp overlay m {bdx}", fig, step) + + if var is not None: + # log aleatoric uncertainty statistics, excluding computation of ECE + log_aleatoric(writer, config, "train", step, var, name, img_meter=None) + + +def discrete_matshow(data, n_colors=5, min=0, max=1): + fig, ax = plt.subplots() + # get discrete colormap + cmap = plt.get_cmap("gray", n_colors + 1) + ax.matshow(data, cmap=cmap, vmin=min, vmax=max) + ax.axis("off") + fig.tight_layout() + return fig + + +def continuous_matshow(data, min=0, max=1): + fig, ax = plt.subplots() + # get discrete colormap + cmap = plt.get_cmap("seismic") + ax.matshow(data, cmap=cmap, vmin=min, vmax=max) + ax.axis("off") + # optionally: provide a colorbar and tick at integers + # cax = plt.colorbar(mat, ticks=np.arange(min, max + 1)) + return fig + + +def iterate(model, data_loader, config, writer, mode="train", epoch=None, device=None): + if len(data_loader) == 0: + raise ValueError("Received data loader with zero samples!") + # loss meter, needs 1 meter per scalar (see https://tnt.readthedocs.io/en/latest/_modules/torchnet/meter/averagevaluemeter.html); + loss_meter = tnt.meter.AverageValueMeter() + img_meter = avg_img_metrics() + + # collect sample-averaged uncertainties and errors + errs, errs_se, errs_ae, vars_aleatoric = [], [], [], [] + + t_start = time.time() + for i, batch in enumerate(tqdm(data_loader)): + step = (epoch - 1) * len(data_loader) + i + + if config.sample_type == "cloudy_cloudfree": + x, y, in_m, dates = prepare_data(batch, device, config) + elif config.sample_type == "pretrain": + x, y, in_m = prepare_data(batch, device, config) + dates = None + else: + raise NotImplementedError + inputs = {"A": x, "B": y, "dates": dates, "masks": in_m} + + if mode != "train": # val or test + with torch.no_grad(): + # compute single-model mean and variance predictions + model.set_input(inputs) + model.forward() + model.get_loss_G() + model.rescale() + out = model.fake_B + if hasattr(model.netG, "variance") and model.netG.variance is not None: + var = model.netG.variance + model.netG.variance = None + else: + var = out[:, :, S2_BANDS:, ...] + out = out[:, :, :S2_BANDS, ...] + batch_size = y.size()[0] + + for bdx in range(batch_size): + # only compute statistics on variance estimates if using e.g. NLL loss or combinations thereof + + if config.loss in ["GNLL", "MGNLL"]: + # if the variance variable is of shape [B x 1 x C x C x H x W] then it's a covariance tensor + if len(var.shape) > 5: + covar = var + # get [B x 1 x C x H x W] variance tensor + var = var.diagonal(dim1=2, dim2=3).moveaxis(-1, 2) + + extended_metrics = img_metrics(y[bdx], out[bdx], var=var[bdx]) + vars_aleatoric.append(extended_metrics["mean var"]) + errs.append(extended_metrics["error"]) + errs_se.append(extended_metrics["mean se"]) + errs_ae.append(extended_metrics["mean ae"]) + else: + extended_metrics = img_metrics(y[bdx], out[bdx]) + + img_meter.add(extended_metrics) + idx = i * batch_size + bdx # plot and export every k-th item + if config.plot_every > 0 and idx % config.plot_every == 0: + plot_dir = os.path.join( + config.res_dir, + config.experiment_name, + "plots", + f"epoch_{epoch}", + f"{mode}", + ) + plot_img(x[bdx], "in", plot_dir, file_id=idx) + plot_img(out[bdx], "pred", plot_dir, file_id=idx) + plot_img(y[bdx], "target", plot_dir, file_id=idx) + plot_img( + ((out[bdx] - y[bdx]) ** 2).mean(1, keepdims=True), + "err", + plot_dir, + file_id=idx, + ) + plot_img( + discrete_matshow( + in_m.float().mean(axis=1).cpu()[bdx], + n_colors=config.input_t, + ), + "mask", + plot_dir, + file_id=idx, + ) + if var is not None: + plot_img( + var.mean(2, keepdims=True)[bdx], + "var", + plot_dir, + file_id=idx, + ) + if config.export_every > 0 and idx % config.export_every == 0: + export_dir = os.path.join( + config.res_dir, + config.experiment_name, + "export", + f"epoch_{epoch}", + f"{mode}", + ) + export(out[bdx], "pred", export_dir, file_id=idx) + export(y[bdx], "target", export_dir, file_id=idx) + if var is not None: + try: + export(covar[bdx], "covar", export_dir, file_id=idx) + except: + export(var[bdx], "var", export_dir, file_id=idx) + else: # training + # compute single-model mean and variance predictions + model.set_input(inputs) + model.optimize_parameters() # not using model.forward() directly + out = model.fake_B.detach().cpu() + + # read variance predictions stored on generator + if hasattr(model.netG, "variance") and model.netG.variance is not None: + var = model.netG.variance.cpu() + else: + var = out[:, :, S2_BANDS:, ...] + out = out[:, :, :S2_BANDS, ...] + + if config.plot_every > 0: + plot_out = out.detach().clone() + batch_size = y.size()[0] + for bdx in range(batch_size): + idx = i * batch_size + bdx # plot and export every k-th item + if idx % config.plot_every == 0: + plot_dir = os.path.join( + config.res_dir, + config.experiment_name, + "plots", + f"epoch_{epoch}", + f"{mode}", + ) + plot_img(x[bdx], "in", plot_dir, file_id=i) + plot_img(plot_out[bdx], "pred", plot_dir, file_id=i) + plot_img(y[bdx], "target", plot_dir, file_id=i) + + if mode == "train": + # periodically log stats + if step % config.display_step == 0: + out, x, y, in_m = out.cpu(), x.cpu(), y.cpu(), in_m.cpu() + if config.loss in ["GNLL", "MGNLL"]: + var = var.cpu() + log_train(writer, config, model, step, x, out, y, in_m, var=var) + else: + log_train(writer, config, model, step, x, out, y, in_m) + + # log the loss, computed via model.backward_G() at train time & via model.get_loss_G() at val/test time + loss_meter.add(model.loss_G.item()) + + # after each batch, close any leftover figures + plt.close("all") + + # --- end of epoch --- + # after each epoch, log the loss metrics + t_end = time.time() + total_time = t_end - t_start + print("Epoch time : {:.1f}s".format(total_time)) + metrics = {f"{mode}_epoch_time": total_time} + # log the loss, only computed within model.backward_G() at train time + metrics[f"{mode}_loss"] = loss_meter.value()[0] + + if mode == "train": # after each epoch, update lr acc. to scheduler + current_lr = model.optimizer_G.state_dict()["param_groups"][0]["lr"] + writer.add_scalar("Etc/train/lr", current_lr, step) + model.scheduler_G.step() + + if mode == "test" or mode == "val": + # log the metrics + + # log image metrics + for key, val in img_meter.value().items(): + writer.add_scalar(f"{mode}/{key}", val, step) + + # any loss is currently only computed within model.backward_G() at train time + writer.add_scalar(f"{mode}/loss", metrics[f"{mode}_loss"], step) + + # use add_images for batch-wise adding across temporal dimension + if config.use_sar: + writer.add_image( + f"Img/{mode}/in_s1", x[0, :, [0], ...], step, dataformats="NCHW" + ) + writer.add_image( + f"Img/{mode}/in_s2", x[0, :, [5, 4, 3], ...], step, dataformats="NCHW" + ) + else: + writer.add_image( + f"Img/{mode}/in_s2", x[0, :, [3, 2, 1], ...], step, dataformats="NCHW" + ) + writer.add_image( + f"Img/{mode}/out", out[0, 0, [3, 2, 1], ...], step, dataformats="CHW" + ) + writer.add_image( + f"Img/{mode}/y", y[0, 0, [3, 2, 1], ...], step, dataformats="CHW" + ) + writer.add_image( + f"Img/{mode}/m", in_m[0, :, None, ...], step, dataformats="NCHW" + ) + + # compute Expected Calibration Error (ECE) + if config.loss in ["GNLL", "MGNLL"]: + sorted_errors_se = compute_ece( + vars_aleatoric, errs_se, len(data_loader.dataset), percent=5 + ) + sorted_errors = {"se_sortAleatoric": sorted_errors_se} + plot_discard( + sorted_errors["se_sortAleatoric"], config, mode, step, is_se=True + ) + + # compute ECE + uce_l2, auce_l2 = compute_uce_auce( + vars_aleatoric, + errs, + len(data_loader.dataset), + percent=5, + l2=True, + mode=mode, + step=step, + ) + + # no need for a running mean here + img_meter.value()["UCE SE"] = uce_l2.cpu().numpy().item() + img_meter.value()["AUCE SE"] = auce_l2.cpu().numpy().item() + + if config.loss in ["GNLL", "MGNLL"]: + log_aleatoric(writer, config, mode, step, var, f"model/", img_meter) + + return metrics, img_meter.value() + else: + return metrics + + +def plot_discard(sorted_errors, config, mode, step, is_se=True): + metric = "SE" if is_se else "AE" + + fig, ax = plt.subplots() + x_axis = np.arange(0.0, 1.0, 0.05) + ax.scatter( + x_axis, + sorted_errors, + c="b", + alpha=1.0, + marker=r".", + label=f"{metric}, sorted by uncertainty", + ) + + # fit a linear regressor with slope b and intercept a + sorted_errors[np.isnan(sorted_errors)] = np.nanmean(sorted_errors) + b, a = np.polyfit(x_axis, sorted_errors, deg=1) + x_seq = np.linspace(0, 1.0, num=1000) + ax.plot( + x_seq, + a + b * x_seq, + c="k", + lw=1.5, + alpha=0.75, + label=f"linear fit, {round(a, 3)} + {round(b, 3)} * x", + ) + plt.xlabel("Fraction of samples, sorted ascendingly by uncertainty") + plt.ylabel("Error") + plt.legend(loc="upper left") + plt.grid() + fig.tight_layout() + writer.add_figure(f"Img/{mode}/discard_uncertain", fig, step) + if mode == "test": # export the final test split plots for print + path_to = os.path.join(config.res_dir, config.experiment_name) + print(f"Logging discard plots to path {path_to}") + fig.savefig( + os.path.join(path_to, f"plot_{mode}_{metric}_discard.png"), + bbox_inches="tight", + dpi=int(1e3), + ) + fig.savefig( + os.path.join(path_to, f"plot_{mode}_{metric}_discard.pdf"), + bbox_inches="tight", + dpi=int(1e3), + ) + + +def compute_ece(vars, errors, n_samples, percent=5): + # rank sample-averaged uncertainties ascendingly, and errors accordingly + _, vars_indices = torch.sort(torch.Tensor(vars)) + errors = torch.Tensor(errors) + errs_sort = errors[vars_indices] + # incrementally remove 5% of errors, ranked by highest uncertainty + bins = torch.linspace(0, n_samples, 100 // percent + 1, dtype=int)[1:] + # get uncertainty-sorted cumulative errors, i.e. at x-tick 65% we report the average error for the 65% most certain predictions + sorted_errors = np.array( + [torch.nanmean(errs_sort[:rdx]).cpu().numpy() for rdx in bins] + ) + + return sorted_errors + + +binarize = lambda arg, n_bins, floor=0, ceil=1: np.digitize( + arg, bins=np.linspace(floor, ceil, num=n_bins)[1:] +) + + +def compute_uce_auce(var, errors, n_samples, percent=5, l2=True, mode="val", step=0): + n_bins = 100 // percent + var, errors = torch.Tensor(var), torch.Tensor(errors) + + # metric: IN: standard deviation & error + # OUT: either root mean variance & root mean squared error or mean standard deviation & mean absolute error + metric = ( + lambda arg: torch.sqrt(torch.mean(arg**2)) + if l2 + else torch.mean(torch.abs(arg)) + ) + m_str = "L2" if l2 else "L1" + + # group uncertainty values into n_bins + var_idx = torch.Tensor(binarize(var, n_bins, floor=var.min(), ceil=var.max())) + + # compute bin-wise statistics, defaults to nan if no data contained in bin + bk_var, bk_err = torch.empty(n_bins), torch.empty(n_bins) + for bin_idx in range(n_bins): # for each of the n_bins ... + bk_var[bin_idx] = metric( + var[var_idx == bin_idx].sqrt() + ) # note: taking the sqrt to wrap into metric function, + bk_err[bin_idx] = metric( + errors[var_idx == bin_idx] + ) # apply same metric function on error + + calib_err = torch.abs( + bk_err - bk_var + ) # calibration error: discrepancy of error vs uncertainty + bk_weight = ( + torch.histogram(var_idx, n_bins)[0] / n_samples + ) # fraction of total data per bin, for bin-weighting + uce = torch.nansum(bk_weight * calib_err) # calc. weighted UCE, + auce = torch.nanmean(calib_err) # calc. unweighted AUCE + + # plot bin-wise error versus bin-wise uncertainty + fig, ax = plt.subplots() + x_min, x_max = bk_var[~bk_var.isnan()].min(), bk_var[~bk_var.isnan()].max() + y_min, y_max = 0, bk_err[~bk_err.isnan()].max() + x_axis = np.linspace(x_min, x_max, num=n_bins) + + ax.plot(x_axis, x_axis) # diagonal reference line + ax.bar( + x_axis, + bk_err, + width=x_axis[1] - x_axis[0], + alpha=0.75, + edgecolor="k", + color="gray", + ) + + plt.xlim(x_min, x_max) + plt.ylim(y_min, y_max) + plt.xlabel("Uncertainty") + plt.ylabel(f"{m_str} Error") + plt.legend(loc="upper left") + plt.grid() + fig.tight_layout() + writer.add_figure(f"Img/{mode}/err_vs_var_{m_str}", fig, step) + + return uce, auce + + +def recursive_todevice(x, device): + if isinstance(x, torch.Tensor): + return x.to(device) + elif isinstance(x, dict): + return {k: recursive_todevice(v, device) for k, v in x.items()} + else: + return [recursive_todevice(c, device) for c in x] + + +def prepare_output(config): + os.makedirs(os.path.join(config.res_dir, config.experiment_name), exist_ok=True) + + +def checkpoint(log, config): + with open( + os.path.join(config.res_dir, config.experiment_name, "trainlog.json"), "w" + ) as outfile: + json.dump(log, outfile, indent=4) + + +def save_results(metrics, path, split="test"): + with open(os.path.join(path, f"{split}_metrics.json"), "w") as outfile: + json.dump(metrics, outfile, indent=4) + + +# check for file of pre-computed statistics, e.g. indices or cloud coverage +def import_from_path(split, config): + if os.path.exists( + os.path.join(os.path.dirname(os.getcwd()), "util", "precomputed") + ): + import_path = os.path.join( + os.path.dirname(os.getcwd()), + "util", + "precomputed", + f"generic_{config.input_t}_{split}_{config.region}_s2cloudless_mask.npy", + ) + else: + import_path = os.path.join( + config.precomputed, + f"generic_{config.input_t}_{split}_{config.region}_s2cloudless_mask.npy", + ) + import_data_path = import_path if os.path.isfile(import_path) else None + return import_data_path + + +def main(config): + prepare_output(config) + device = torch.device(config.device) + + # define data sets + if config.pretrain: # pretrain / training on mono-temporal data + dt_train = SEN12MSCR( + os.path.expanduser(config.root3), + split="train", + region=config.region, + sample_type=config.sample_type, + ) + dt_val = SEN12MSCR( + os.path.expanduser(config.root3), + split="test", + region=config.region, + sample_type=config.sample_type, + ) + dt_test = SEN12MSCR( + os.path.expanduser(config.root3), + split="val", + region=config.region, + sample_type=config.sample_type, + ) + else: + pass + + # wrap to allow for subsampling, e.g. for test runs etc + dt_train = torch.utils.data.Subset( + dt_train, + range( + 0, + min( + config.max_samples_count, + len(dt_train), + int(len(dt_train) * config.max_samples_frac), + ), + ), + ) + dt_val = torch.utils.data.Subset( + dt_val, + range( + 0, + min( + config.max_samples_count, + len(dt_val), + int(len(dt_train) * config.max_samples_frac), + ), + ), + ) + dt_test = torch.utils.data.Subset( + dt_test, + range( + 0, + min( + config.max_samples_count, + len(dt_test), + int(len(dt_train) * config.max_samples_frac), + ), + ), + ) + + # instantiate dataloaders, note: worker_init_fn is needed to get reproducible random samples across runs if vary_samples=True + train_loader = torch.utils.data.DataLoader( + dt_train, + batch_size=config.batch_size, + shuffle=True, + worker_init_fn=seed_worker, + generator=f, + num_workers=config.num_workers, + pin_memory=True, + persistent_workers=True, + ) + val_loader = torch.utils.data.DataLoader( + dt_val, + batch_size=config.batch_size, + shuffle=False, + worker_init_fn=seed_worker, + generator=g, + num_workers=config.num_workers, + pin_memory=True, + persistent_workers=True, + ) + test_loader = torch.utils.data.DataLoader( + dt_test, + batch_size=config.batch_size, + shuffle=False, + worker_init_fn=seed_worker, + generator=g, + num_workers=config.num_workers, + pin_memory=True, + persistent_workers=True, + ) + + print("Train {}, Val {}, Test {}".format(len(dt_train), len(dt_val), len(dt_test))) + + # model definition + # (compiled model hangs up in validation step on some systems, retry in the future for pytorch > 2.0) + model = get_model(config) # torch.compile(get_model(config)) + + # set model properties + model.len_epoch = len(train_loader) + + config.N_params = utils.get_ntrainparams(model) + print("\n\nTrainable layers:") + for name, p in model.named_parameters(): + if p.requires_grad: + print(f"\t{name}") + model = model.to(device) + # do random weight initialization + print("\nInitializing weights randomly.") + model.netG.apply(weight_init) + + if config.trained_checkp and len(config.trained_checkp) > 0: + # load weights from the indicated checkpoint + print(f"Loading weights from (pre-)trained checkpoint {config.trained_checkp}") + load_model( + config, + model, + train_out_layer=False, + load_out_partly=config.model in ["uncrtaints"], + ) + + with open( + os.path.join(config.res_dir, config.experiment_name, "conf.json"), "w" + ) as file: + file.write(json.dumps(vars(config), indent=4)) + print(f"TOTAL TRAINABLE PARAMETERS: {config.N_params}\n") + print(model) + + # Optimizer and Loss + model.criterion = losses.get_loss(config) + + # track best loss, checkpoint at best validation performance + is_better, best_loss = lambda new, prev: new <= prev, float("inf") + + # Training loop + trainlog = {} + + # resume training at scheduler's latest epoch, != 0 if --resume_from + begin_at = ( + config.resume_at + if config.resume_at >= 0 + else model.scheduler_G.state_dict()["last_epoch"] + ) + for epoch in range(begin_at + 1, config.epochs + 1): + print("\nEPOCH {}/{}".format(epoch, config.epochs)) + + # put all networks in training mode again + model.train() + model.netG.train() + + # unfreeze all layers after specified epoch + if epoch > config.unfreeze_after and hasattr(model, "frozen") and model.frozen: + print("Unfreezing all network layers") + model.frozen = False + freeze_layers(model.netG, grad=True) + + # re-seed train generator for each epoch anew, depending on seed choice plus current epoch number + # ~ else, dataloader provides same samples no matter what epoch training starts/resumes from + # ~ note: only re-seed train split dataloader (if config.vary_samples), but keep all others consistent + # ~ if desiring different runs, then the seeds must at least be config.epochs numbers apart + if config.vary_samples: + # condition dataloader samples on current epoch count + f.manual_seed(config.rdm_seed + epoch) + train_loader = torch.utils.data.DataLoader( + dt_train, + batch_size=config.batch_size, + shuffle=True, + worker_init_fn=seed_worker, + generator=f, + num_workers=config.num_workers, + ) + + train_metrics = iterate( + model, + data_loader=train_loader, + config=config, + writer=writer, + mode="train", + epoch=epoch, + device=device, + ) + + # do regular validation steps at the end of each training epoch + if epoch % config.val_every == 0 and epoch > config.val_after: + print("Validation . . . ") + + model.eval() + model.netG.eval() + + val_metrics, val_img_metrics = iterate( + model, + data_loader=val_loader, + config=config, + writer=writer, + mode="val", + epoch=epoch, + device=device, + ) + # use the training loss for validation + print("Using training loss as validation loss") + if "val_loss" in val_metrics: + val_loss = val_metrics["val_loss"] + else: + val_loss = val_metrics["val_loss_ensembleAverage"] + + print(f"Validation Loss {val_loss}") + print(f"validation image metrics: {val_img_metrics}") + save_results( + val_img_metrics, + os.path.join(config.res_dir, config.experiment_name), + split=f"test_epoch_{epoch}", + ) + print( + f"\nLogged validation epoch {epoch} metrics to path {os.path.join(config.res_dir, config.experiment_name)}" + ) + + # checkpoint best model + trainlog[epoch] = {**train_metrics, **val_metrics} + checkpoint(trainlog, config) + if is_better(val_loss, best_loss): + best_loss = val_loss + save_model(config, epoch, model, "model") + else: + trainlog[epoch] = {**train_metrics} + checkpoint(trainlog, config) + + # always checkpoint the current epoch's model + save_model(config, epoch, model, f"model_epoch_{epoch}") + + print(f"Completed current epoch of experiment {config.experiment_name}.") + + # following training, test on hold-out data + print("Testing best epoch . . .") + load_checkpoint(config, config.res_dir, model, "model") + + model.eval() + model.netG.eval() + + test_metrics, test_img_metrics = iterate( + model, + data_loader=test_loader, + config=config, + writer=writer, + mode="test", + epoch=epoch, + device=device, + ) + + if "test_loss" in test_metrics: + test_loss = test_metrics["test_loss"] + else: + test_loss = test_metrics["test_loss_ensembleAverage"] + print(f"Test Loss {test_loss}") + print(f"\nTest image metrics: {test_img_metrics}") + save_results( + test_img_metrics, + os.path.join(config.res_dir, config.experiment_name), + split="test", + ) + print( + f"\nLogged test metrics to path {os.path.join(config.res_dir, config.experiment_name)}" + ) + + # close tensorboard logging + writer.close() + + print(f"Finished training experiment {config.experiment_name}.") + + +if __name__ == "__main__": + main(config) + exit() diff --git a/UnCRtainTS/requirements.txt b/UnCRtainTS/requirements.txt new file mode 100644 index 0000000000000000000000000000000000000000..e22ef11c1b6f6bdb5513dd18fe0111594f08eee0 --- /dev/null +++ b/UnCRtainTS/requirements.txt @@ -0,0 +1,29 @@ +# torchnet~=0.0.4 +# torch_scatter +# pandas~=1.0.4 +# geopandas~=0.8.1 +# tqdm~=4.64.0 +# matplotlib~=3.5.2 +# rasterio~=1.2.10 +# torchgeometry~=0.1.2 +# s2cloudless~=1.6.0 +# natsort +# fvcore +# tensorboard +# protobuf~=3.20.1 + + +matplotlib +torchmetrics==0.6.0 +torchnet==0.0.4 +scipy +scikit-image +rasterio +natsort +tqdm +Pillow +dominate +visdom +fvcore +tensorboard +s2cloudless==1.6.0 \ No newline at end of file diff --git a/UnCRtainTS/show.py b/UnCRtainTS/show.py new file mode 100644 index 0000000000000000000000000000000000000000..b1c09ce3660ca28cde5ccd42b4328277c082bc5f --- /dev/null +++ b/UnCRtainTS/show.py @@ -0,0 +1,33 @@ +from PIL import Image +import numpy as np +import matplotlib.pyplot as plt + +pred = "model/inference/monotemporalL2/export/epoch_1/test/img-0_pred.npy" +target = "model/inference/monotemporalL2/export/epoch_1/test/img-0_target.npy" +var = "model/inference/monotemporalL2/export/epoch_1/test/img-0_var.npy" + +pred = np.load(pred) +target = np.load(target) +var = np.load(var) + +print(pred.shape) # (13, 256, 256) +print(target.shape) # (13, 256, 256) +print(var.shape) # (0, 256, 256) + +print(pred.dtype) # float32 +print(target.dtype) # float32 +print(var.dtype) # float32 + +print(pred.min(), pred.max()) # 0.0 1.0 +print(target.min(), target.max()) # 0.0 1.0 +# print(var.min(), var.max()) # nan nan + +rgb = [3, 2, 1] + +fig, ax = plt.subplots() +# get discrete colormap +cmap = plt.get_cmap("gray", 13) +ax.matshow(np.transpose(pred, (1, 2, 0)), cmap=cmap, vmin=0, vmax=1) +ax.axis("off") +fig.tight_layout() +plt.savefig("pred.png") \ No newline at end of file diff --git a/UnCRtainTS/standalone_dataloader.py b/UnCRtainTS/standalone_dataloader.py new file mode 100644 index 0000000000000000000000000000000000000000..2192218e7f955703a15faf888839a043f23af285 --- /dev/null +++ b/UnCRtainTS/standalone_dataloader.py @@ -0,0 +1,32 @@ +# minimal Python script to demonstrate utilizing the pyTorch data loader for SEN12MS-CR and SEN12MS-CR-TS + +import os +import torch +from data.dataLoader import SEN12MSCR, SEN12MSCRTS + +if __name__ == '__main__': + + # main parameters for instantiating SEN12MS-CR-TS + dataset = 'SEN12MS-CR-TS' # choose either 'SEN12MS-CR' or 'SEN12MS-CR-TS' + root = '/home/data/' # path to your copy of SEN12MS-CR or SEN12MS-CR-TS + split = 'all' # ROI to sample from, belonging to splits [all | train | val | test] + input_t = 3 # number of input time points to sample, only relevant for SEN12MS-CR-TS + import_path = None # path to importing the suppl. file specifying what time points to load for input and output + sample_type = 'cloudy_cloudfree' # type of samples returned [cloudy_cloudfree | generic] + + assert dataset in ['SEN12MS-CR', 'SEN12MS-CR-TS'] + if dataset =='SEN12MS-CR': loader = SEN12MSCR(os.path.join(root, 'SEN12MSCR'), split=split) + else: loader = SEN12MSCRTS(os.path.join(root, 'SEN12MSCRTS'), split=split, sample_type=sample_type, n_input_samples=input_t, import_data_path=import_path) + dataloader = torch.utils.data.DataLoader(loader, batch_size=1, shuffle=False, num_workers=10) + + # iterate over split and do some data accessing for demonstration + for pdx, patch in enumerate(dataloader): + print(f'Fetching {pdx}. batch of data.') + + input_s1 = patch['input']['S1'] + input_s2 = patch['input']['S2'] + input_c = sum(patch['input']['coverage'])/len(patch['input']['coverage']) + output_s2 = patch['target']['S2'] + if dataset=='SEN12MS-CR-TS': + dates_s1 = patch['input']['S1 TD'] + dates_s2 = patch['input']['S2 TD'] diff --git a/UnCRtainTS/test_diffcr_bs32_epoch17.sh b/UnCRtainTS/test_diffcr_bs32_epoch17.sh new file mode 100644 index 0000000000000000000000000000000000000000..200b570d6378759e5ed0f279ac02002d6215487e --- /dev/null +++ b/UnCRtainTS/test_diffcr_bs32_epoch17.sh @@ -0,0 +1 @@ +python model/test_reconstruct.py --experiment_name diffcr_bs32_epoch17 --root3 data2/SEN12MSCR --input_t 1 --region all --export_every 1 --res_dir ./inference --weight_folder checkpoint/ --pretrain --sample_type pretrain --device cuda:6 --use_sar --out_conv 13 \ No newline at end of file diff --git a/UnCRtainTS/test_monotemporalL2.sh b/UnCRtainTS/test_monotemporalL2.sh new file mode 100644 index 0000000000000000000000000000000000000000..f02d962d402018024dea54e662834bf4d16d35fc --- /dev/null +++ b/UnCRtainTS/test_monotemporalL2.sh @@ -0,0 +1 @@ +python model/test_reconstruct.py --experiment_name monotemporalL2 --root3 data2/SEN12MSCR --input_t 1 --region all --export_every 1 --res_dir ./inference --weight_folder checkpoint/ --pretrain --sample_type pretrain --device cuda:6 \ No newline at end of file diff --git a/UnCRtainTS/train_diffcr_bs32.sh b/UnCRtainTS/train_diffcr_bs32.sh new file mode 100644 index 0000000000000000000000000000000000000000..20be7d8b8009c06a4ca2086bfd8948c4e9fbd006 --- /dev/null +++ b/UnCRtainTS/train_diffcr_bs32.sh @@ -0,0 +1,43 @@ +python model/train_reconstruct.py \ +--experiment_name diffcr_bs8 \ +--root3 data2/SEN12MSCR \ +--model uncrtaints \ +--encoder_widths [128] \ +--decoder_widths [128,128,128,128,128] \ +--out_conv 13 \ +--mean_nonLinearity \ +--var_nonLinearity softplus \ +--use_sar \ +--agg_mode att_group \ +--encoder_norm group \ +--decoder_norm batch \ +--n_head 1 \ +--d_model 256 \ +--positional_encoding \ +--d_k 4 \ +--res_dir ./results \ +--device cuda:0 \ +--display_step 10 \ +--batch_size 32 \ +--num_workers 16 \ +--lr 0.001 \ +--gamma 0.8 \ +--ref_date 2014-04-03 \ +--pad_value 0 \ +--padding_mode reflect \ +--val_every 1 \ +--val_after 0 \ +--pretrain \ +--input_t 1 \ +--sample_type pretrain \ +--vary_samples \ +--min_cov 0.0 \ +--max_cov 1.0 \ +--region all \ +--max_samples_count 1000000000 \ +--plot_every -1 \ +--loss l2 \ +--covmode diag \ +--scale_by 10.0 \ +--epochs 20 \ +--trained_checkp "" diff --git a/UnCRtainTS/train_diffcr_bs32_lr1e-3.sh b/UnCRtainTS/train_diffcr_bs32_lr1e-3.sh new file mode 100644 index 0000000000000000000000000000000000000000..a382584a8f75a9b8fdec7e032115a7c3975647b8 --- /dev/null +++ b/UnCRtainTS/train_diffcr_bs32_lr1e-3.sh @@ -0,0 +1,43 @@ +python model/train_reconstruct.py \ +--experiment_name diffcr_bs32_lr1e-3 \ +--root3 data2/SEN12MSCR \ +--model uncrtaints \ +--encoder_widths [128] \ +--decoder_widths [128,128,128,128,128] \ +--out_conv 13 \ +--mean_nonLinearity \ +--var_nonLinearity softplus \ +--use_sar \ +--agg_mode att_group \ +--encoder_norm group \ +--decoder_norm batch \ +--n_head 1 \ +--d_model 256 \ +--positional_encoding \ +--d_k 4 \ +--res_dir ./results \ +--device cuda:7 \ +--display_step 10 \ +--batch_size 32 \ +--num_workers 16 \ +--lr 0.001 \ +--gamma 0.8 \ +--ref_date 2014-04-03 \ +--pad_value 0 \ +--padding_mode reflect \ +--val_every 1 \ +--val_after 0 \ +--pretrain \ +--input_t 1 \ +--sample_type pretrain \ +--vary_samples \ +--min_cov 0.0 \ +--max_cov 1.0 \ +--region all \ +--max_samples_count 1000000000 \ +--plot_every -1 \ +--loss l2 \ +--covmode diag \ +--scale_by 10.0 \ +--epochs 100 \ +--trained_checkp "" diff --git a/UnCRtainTS/train_diffcr_bs32_lr5e-4.sh b/UnCRtainTS/train_diffcr_bs32_lr5e-4.sh new file mode 100644 index 0000000000000000000000000000000000000000..c642945ae858d0d0cf18ac1e920f245df21c0db3 --- /dev/null +++ b/UnCRtainTS/train_diffcr_bs32_lr5e-4.sh @@ -0,0 +1,43 @@ +python model/train_reconstruct.py \ +--experiment_name diffcr_bs32_lr15e-4 \ +--root3 data2/SEN12MSCR \ +--model uncrtaints \ +--encoder_widths [128] \ +--decoder_widths [128,128,128,128,128] \ +--out_conv 13 \ +--mean_nonLinearity \ +--var_nonLinearity softplus \ +--use_sar \ +--agg_mode att_group \ +--encoder_norm group \ +--decoder_norm batch \ +--n_head 1 \ +--d_model 256 \ +--positional_encoding \ +--d_k 4 \ +--res_dir ./results \ +--device cuda:0 \ +--display_step 10 \ +--batch_size 32 \ +--num_workers 16 \ +--lr 0.0005 \ +--gamma 0.8 \ +--ref_date 2014-04-03 \ +--pad_value 0 \ +--padding_mode reflect \ +--val_every 1 \ +--val_after 0 \ +--pretrain \ +--input_t 1 \ +--sample_type pretrain \ +--vary_samples \ +--min_cov 0.0 \ +--max_cov 1.0 \ +--region all \ +--max_samples_count 1000000000 \ +--plot_every -1 \ +--loss l2 \ +--covmode diag \ +--scale_by 10.0 \ +--epochs 100 \ +--trained_checkp "" diff --git a/UnCRtainTS/train_diffcr_bs4.sh b/UnCRtainTS/train_diffcr_bs4.sh new file mode 100644 index 0000000000000000000000000000000000000000..efcd11fa26262116d2b0bed7e92d1858f0920b3a --- /dev/null +++ b/UnCRtainTS/train_diffcr_bs4.sh @@ -0,0 +1,43 @@ +python model/train_reconstruct.py \ +--experiment_name diffcr_bs4 \ +--root3 data2/SEN12MSCR \ +--model uncrtaints \ +--encoder_widths [128] \ +--decoder_widths [128,128,128,128,128] \ +--out_conv 13 \ +--mean_nonLinearity \ +--var_nonLinearity softplus \ +--use_sar \ +--agg_mode att_group \ +--encoder_norm group \ +--decoder_norm batch \ +--n_head 1 \ +--d_model 256 \ +--positional_encoding \ +--d_k 4 \ +--res_dir ./results \ +--device cuda:0 \ +--display_step 10 \ +--batch_size 4 \ +--num_workers 4 \ +--lr 0.001 \ +--gamma 0.8 \ +--ref_date 2014-04-03 \ +--pad_value 0 \ +--padding_mode reflect \ +--val_every 1 \ +--val_after 0 \ +--pretrain \ +--input_t 1 \ +--sample_type pretrain \ +--vary_samples \ +--min_cov 0.0 \ +--max_cov 1.0 \ +--region all \ +--max_samples_count 1000000000 \ +--plot_every -1 \ +--loss l2 \ +--covmode diag \ +--scale_by 10.0 \ +--epochs 20 \ +--trained_checkp "" diff --git a/UnCRtainTS/train_diffcr_bs64.sh b/UnCRtainTS/train_diffcr_bs64.sh new file mode 100644 index 0000000000000000000000000000000000000000..c0e2ebc940b4577c02aebf71a22bbe57ef6cb521 --- /dev/null +++ b/UnCRtainTS/train_diffcr_bs64.sh @@ -0,0 +1,43 @@ +python model/train_reconstruct.py \ +--experiment_name diffcr_bs64 \ +--root3 data2/SEN12MSCR \ +--model uncrtaints \ +--encoder_widths [128] \ +--decoder_widths [128,128,128,128,128] \ +--out_conv 13 \ +--mean_nonLinearity \ +--var_nonLinearity softplus \ +--use_sar \ +--agg_mode att_group \ +--encoder_norm group \ +--decoder_norm batch \ +--n_head 1 \ +--d_model 256 \ +--positional_encoding \ +--d_k 4 \ +--res_dir ./results \ +--device cuda:3 \ +--display_step 10 \ +--batch_size 64 \ +--num_workers 32 \ +--lr 0.001 \ +--gamma 0.8 \ +--ref_date 2014-04-03 \ +--pad_value 0 \ +--padding_mode reflect \ +--val_every 1 \ +--val_after 0 \ +--pretrain \ +--input_t 1 \ +--sample_type pretrain \ +--vary_samples \ +--min_cov 0.0 \ +--max_cov 1.0 \ +--region all \ +--max_samples_count 1000000000 \ +--plot_every -1 \ +--loss l2 \ +--covmode diag \ +--scale_by 10.0 \ +--epochs 20 \ +--trained_checkp "" diff --git a/UnCRtainTS/train_diffcr_bs8.sh b/UnCRtainTS/train_diffcr_bs8.sh new file mode 100644 index 0000000000000000000000000000000000000000..3dd2fd3faadc4a5177d0e6320518f48a1e31dd64 --- /dev/null +++ b/UnCRtainTS/train_diffcr_bs8.sh @@ -0,0 +1,43 @@ +python model/train_reconstruct.py \ +--experiment_name diffcr_bs8 \ +--root3 data2/SEN12MSCR \ +--model uncrtaints \ +--encoder_widths [128] \ +--decoder_widths [128,128,128,128,128] \ +--out_conv 13 \ +--mean_nonLinearity \ +--var_nonLinearity softplus \ +--use_sar \ +--agg_mode att_group \ +--encoder_norm group \ +--decoder_norm batch \ +--n_head 1 \ +--d_model 256 \ +--positional_encoding \ +--d_k 4 \ +--res_dir ./results \ +--device cuda:2 \ +--display_step 10 \ +--batch_size 8 \ +--num_workers 4 \ +--lr 0.001 \ +--gamma 0.8 \ +--ref_date 2014-04-03 \ +--pad_value 0 \ +--padding_mode reflect \ +--val_every 1 \ +--val_after 0 \ +--pretrain \ +--input_t 1 \ +--sample_type pretrain \ +--vary_samples \ +--min_cov 0.0 \ +--max_cov 1.0 \ +--region all \ +--max_samples_count 1000000000 \ +--plot_every -1 \ +--loss l2 \ +--covmode diag \ +--scale_by 10.0 \ +--epochs 20 \ +--trained_checkp "" diff --git a/UnCRtainTS/train_monotemporalL2.sh b/UnCRtainTS/train_monotemporalL2.sh new file mode 100644 index 0000000000000000000000000000000000000000..057bbfec1e985d35e483979b746e1827b2270597 --- /dev/null +++ b/UnCRtainTS/train_monotemporalL2.sh @@ -0,0 +1,45 @@ +python model/train_reconstruct.py \ +--experiment_name monotemporalL2 \ +--model uncrtaints \ +--encoder_widths [128] \ +--decoder_widths [128,128,128,128,128] \ +--out_conv 13 \ +--mean_nonLinearity true \ +--var_nonLinearity softplus \ +--use_sar true \ +--agg_mode att_group \ +--encoder_norm group \ +--decoder_norm batch \ +--n_head 1 \ +--d_model 256 \ +--use_v false \ +--positional_encoding true \ +--d_k 4 \ +--res_dir ./results \ +--device cuda \ +--display_step 10 \ +--batch_size 4 \ +--lr 0.001 \ +--gamma 0.8 \ +--ref_date 2014-04-03 \ +--pad_value 0 \ +--padding_mode reflect \ +--val_every 1 \ +--val_after 0 \ +--pretrain true \ +--input_t 1 \ +--sample_type pretrain \ +--vary_samples true \ +--min_cov 0.0 \ +--max_cov 1.0 \ +--region all \ +--max_samples_count 1000000000 \ +--input_size 256 \ +--plot_every -1 \ +--loss l2 \ +--covmode diag \ +--scale_by 10.0 \ +--separate_out false \ +--resume_from false \ +--epochs 20 \ +--trained_checkp "" diff --git a/UnCRtainTS/util/__init__.py b/UnCRtainTS/util/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..ae36f63d8859ec0c60dcbfe67c4ac324e751ddf7 --- /dev/null +++ b/UnCRtainTS/util/__init__.py @@ -0,0 +1 @@ +"""This package includes a miscellaneous collection of useful helper functions.""" diff --git a/UnCRtainTS/util/__pycache__/__init__.cpython-311.pyc b/UnCRtainTS/util/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..be1932cb1055de99353ed8634281478283049d71 Binary files /dev/null and b/UnCRtainTS/util/__pycache__/__init__.cpython-311.pyc differ diff --git a/UnCRtainTS/util/__pycache__/detect_cloudshadow.cpython-311.pyc b/UnCRtainTS/util/__pycache__/detect_cloudshadow.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..01cdaad33f7c4185b3d5a147c4d458c867d2109e Binary files /dev/null and b/UnCRtainTS/util/__pycache__/detect_cloudshadow.cpython-311.pyc differ diff --git a/UnCRtainTS/util/detect_cloudshadow.py b/UnCRtainTS/util/detect_cloudshadow.py new file mode 100644 index 0000000000000000000000000000000000000000..2434e79bd95992c5e5dfc1a0943c0e2820bc55e5 --- /dev/null +++ b/UnCRtainTS/util/detect_cloudshadow.py @@ -0,0 +1,93 @@ +import numpy as np +import scipy +import scipy.signal as scisig + + +def rescale(data, limits): + return (data - limits[0]) / (limits[1] - limits[0]) + + +def normalized_difference(channel1, channel2): + subchan = channel1 - channel2 + sumchan = channel1 + channel2 + sumchan[sumchan == 0] = 0.001 # checking for 0 divisions + return subchan / sumchan + + +def get_shadow_mask(data_image): + data_image = data_image / 10000. + + (ch, r, c) = data_image.shape + shadowmask = np.zeros((r, c)).astype('float32') + + BB = data_image[1] + BNIR = data_image[7] + BSWIR1 = data_image[11] + + CSI = (BNIR + BSWIR1) / 2. + + t3 = 3/4 # cloud-score index threshold + T3 = np.min(CSI) + t3 * (np.mean(CSI) - np.min(CSI)) + + t4 = 5 / 6 # water-body index threshold + T4 = np.min(BB) + t4 * (np.mean(BB) - np.min(BB)) + + shadow_tf = np.logical_and(CSI < T3, BB < T4) + + shadowmask[shadow_tf] = -1 + shadowmask = scisig.medfilt2d(shadowmask, 5) + + return shadowmask + + +def get_cloud_mask(data_image, cloud_threshold, binarize=False, use_moist_check=False): + '''Adapted from https://github.com/samsammurphy/cloud-masking-sentinel2/blob/master/cloud-masking-sentinel2.ipynb''' + + data_image = data_image / 10000. + (ch, r, c) = data_image.shape + + # Cloud until proven otherwise + score = np.ones((r, c)).astype('float32') + # Clouds are reasonably bright in the blue and aerosol/cirrus bands. + score = np.minimum(score, rescale(data_image[1], [0.1, 0.5])) + score = np.minimum(score, rescale(data_image[0], [0.1, 0.3])) + score = np.minimum(score, rescale((data_image[0] + data_image[10]), [0.4, 0.9])) + score = np.minimum(score, rescale((data_image[3] + data_image[2] + data_image[1]), [0.2, 0.8])) + + if use_moist_check: + # Clouds are moist + ndmi = normalized_difference(data_image[7], data_image[11]) + score = np.minimum(score, rescale(ndmi, [-0.1, 0.1])) + + # However, clouds are not snow. + ndsi = normalized_difference(data_image[2], data_image[11]) + score = np.minimum(score, rescale(ndsi, [0.8, 0.6])) + + boxsize = 7 + box = np.ones((boxsize, boxsize)) / (boxsize ** 2) + + score = scipy.ndimage.morphology.grey_closing(score, size=(5, 5)) + score = scisig.convolve2d(score, box, mode='same') + + score = np.clip(score, 0.00001, 1.0) + + if binarize: + score[score >= cloud_threshold] = 1 + score[score < cloud_threshold] = 0 + + return score + +# IN: [13 x H x W] S2 image (of arbitrary resolution H,W), scalar cloud detection threshold +# OUT: cloud & shadow segmentation mask (of same resolution) +# the multispectral S2 images are expected to have their default ranges and not be value-standardized yet +# cloud_threshold: the higher the more conservative the masks (i.e. less pixels labeled clouds/shadows) +def get_cloud_cloudshadow_mask(data_image, cloud_threshold): + cloud_mask = get_cloud_mask(data_image, cloud_threshold, binarize=True) + shadow_mask = get_shadow_mask(data_image) + + # encode clouds and shadows as segmentation masks + cloud_cloudshadow_mask = np.zeros_like(cloud_mask) + cloud_cloudshadow_mask[shadow_mask < 0] = -1 + cloud_cloudshadow_mask[cloud_mask > 0] = 1 + + return cloud_cloudshadow_mask diff --git a/UnCRtainTS/util/dl_data.sh b/UnCRtainTS/util/dl_data.sh new file mode 100644 index 0000000000000000000000000000000000000000..b731f27f48643e6f8ef89749534cdc96f876a562 --- /dev/null +++ b/UnCRtainTS/util/dl_data.sh @@ -0,0 +1,311 @@ +#!/bin/bash + +# Script to download, extract and arrange SEN12MS-CR-TS and SEN12MS-CR. +# Make this script executable (by running: chmod +x dl_data.sh), +# then give it a run (by calling: ./dl_data.sh) and +# follow the prompts in order to get the desired data. + +clear +echo "This script is for downloading the SEN12MS-CR-TS data set for cloud removal in satellite data." +echo See the associated paper: Ebel et al \(2022\) \'SEN12MS-CR-TS: A Remote Sensing Data Set for Multi-modal Multi-temporal Cloud Removal\' +echo -e 'Click \e]8;;https://patricktum.github.io/cloud_removal/\ahere\e]8;;\a for more information' +echo +echo + +while true; do + read -p "Do you wish to download the multitemporal SEN12MS-CR-TS data set? " yn + case $yn in + [Yy]* ) SEN12MSCRTS=true; break;; + [Nn]* ) SEN12MSCRTS=false; break;; + * ) echo "Please answer yes or no.";; + esac +done + +if [ "$SEN12MSCRTS" = "true" ]; then + while true; do + read -p "What regions would you like to download? [all|africa|america|asiaEast|asiaWest|europa] " region + case $region in + all|africa|america|asiaEast|asiaWest|europa ) reg=$region; break;; + * ) echo "Please answer [all|africa|america|asiaEast|asiaWest|europa].";; + esac + done +fi + +while true; do + read -p "Do you wish to also download the monotemporal SEN12MS-CR data set (all regions)? " yn + case $yn in + [Yy]* ) SEN12MSCR=true; break;; + [Nn]* ) SEN12MSCR=false; break;; + * ) echo "Please answer yes or no.";; + esac +done + +while true; do + read -p "Do you wish to also download the Sentinel-1 radar data associated with your previous choices? " yn + case $yn in + [Yy]* ) S1=true; break;; + [Nn]* ) S1=false; break;; + * ) echo "Please answer yes or no.";; + esac +done + +declare -A url_dict # holding links to data +declare -A vol_dict # bookkeeping size of data + +echo "Please enter the path to download and extract the data to: " +read dl_extract_to + + +echo +echo +if [ "$SEN12MSCRTS" = "true" ]; then + + echo "Downloading SEN12MS-CR-TS data set." + mkdir -p $dl_extract_to'/SEN12MSCRTS' + + # train split + case $region in + 'all') url_dict['multi_s2_africa']='https://dataserv.ub.tum.de/s/m1639953/download?path=/&files=s2_africa.tar.gz' + vol_dict['multi_s2_africa']='98233900' + + url_dict['multi_s2_america']='https://dataserv.ub.tum.de/s/m1639953/download?path=/&files=s2_america.tar.gz' + vol_dict['multi_s2_america']='110245004' + + url_dict['multi_s2_asiaEast']='https://dataserv.ub.tum.de/s/m1639953/download?path=/&files=s2_asiaEast.tar.gz' + vol_dict['multi_s2_asiaEast']='113948560' + + url_dict['multi_s2_asiaWest']='https://dataserv.ub.tum.de/s/m1639953/download?path=/&files=s2_asiaWest.tar.gz' + vol_dict['multi_s2_asiaWest']='96082796' + + url_dict['multi_s2_europa']='https://dataserv.ub.tum.de/s/m1639953/download?path=/&files=s2_europa.tar.gz' + vol_dict['multi_s2_europa']='196669740' + ;; + 'africa') url_dict['multi_s2_africa']='https://dataserv.ub.tum.de/s/m1639953/download?path=/&files=s2_africa.tar.gz' + vol_dict['multi_s2_africa']='98233900' + ;; + 'america') url_dict['multi_s2_america']='https://dataserv.ub.tum.de/s/m1639953/download?path=/&files=s2_america.tar.gz' + vol_dict['multi_s2_america']='110245004' + ;; + 'asiaEast') url_dict['multi_s2_asiaEast']='https://dataserv.ub.tum.de/s/m1639953/download?path=/&files=s2_asiaEast.tar.gz' + vol_dict['multi_s2_asiaEast']='113948560' + ;; + 'asiaWest') url_dict['multi_s2_asiaWest']='https://dataserv.ub.tum.de/s/m1639953/download?path=/&files=s2_asiaWest.tar.gz' + vol_dict['multi_s2_asiaWest']='96082796' + ;; + 'europa') url_dict['multi_s2_europa']='https://dataserv.ub.tum.de/s/m1639953/download?path=/&files=s2_europa.tar.gz' + vol_dict['multi_s2_europa']='196669740' + ;; + esac + + + # test split + case $region in + 'all') url_dict['multi_s2_africa_test']='https://dataserv.ub.tum.de/s/m1659251/download?path=/&files=s2_africa_test.tar.gz' + vol_dict['multi_s2_africa_test']='25421744' + + url_dict['multi_s2_america_test']='https://dataserv.ub.tum.de/s/m1659251/download?path=/&files=s2_america_test.tar.gz' + vol_dict['multi_s2_america_test']='25421824' + + url_dict['multi_s2_asiaEast_test']='https://dataserv.ub.tum.de/s/m1659251/download?path=/&files=s2_asiaEast_test.tar.gz' + vol_dict['multi_s2_asiaEast_test']='40534760' + + url_dict['multi_s2_asiaWest_test']='https://dataserv.ub.tum.de/s/m1659251/download?path=/&files=s2_asiaWest_test.tar.gz' + vol_dict['multi_s2_asiaWest_test']='15012924' + + url_dict['multi_s2_europa_test']='https://dataserv.ub.tum.de/s/m1659251/download?path=/&files=s2_europa_test.tar.gz' + vol_dict['multi_s2_europa_test']='79568460' + ;; + 'africa') url_dict['multi_s2_africa_test']='https://dataserv.ub.tum.de/s/m1659251/download?path=/&files=s2_africa_test.tar.gz' + vol_dict['multi_s2_africa_test']='25421744' + ;; + 'america') url_dict['multi_s2_america_test']='https://dataserv.ub.tum.de/s/m1659251/download?path=/&files=s2_america_test.tar.gz' + vol_dict['multi_s2_america_test']='25421824' + ;; + 'asiaEast') url_dict['multi_s2_asiaEast_test']='https://dataserv.ub.tum.de/s/m1659251/download?path=/&files=s2_asiaEast_test.tar.gz' + vol_dict['multi_s2_asiaEast_test']='40534760' + ;; + 'asiaWest') url_dict['multi_s2_asiaWest_test']='https://dataserv.ub.tum.de/s/m1659251/download?path=/&files=s2_asiaWest_test.tar.gz' + vol_dict['multi_s2_asiaWest_test']='15012924' + ;; + 'europa') url_dict['multi_s2_europa_test']='https://dataserv.ub.tum.de/s/m1659251/download?path=/&files=s2_europa_test.tar.gz' + vol_dict['multi_s2_europa_test']='79568460' + ;; + esac + + + if [ "$S1" = "true" ]; then + # train split + case $region in + 'all') url_dict['multi_s1_africa']='https://dataserv.ub.tum.de/s/m1639953/download?path=/&files=s1_africa.tar.gz' + vol_dict['multi_s1_africa']='60544524' + + url_dict['multi_s1_america']='https://dataserv.ub.tum.de/s/m1639953/download?path=/&files=s1_america.tar.gz' + vol_dict['multi_s1_america']='67947416' + + url_dict['multi_s1_asiaEast']='https://dataserv.ub.tum.de/s/m1639953/download?path=/&files=s1_asiaEast.tar.gz' + vol_dict['multi_s1_asiaEast']='70230104' + + url_dict['multi_s1_asiaWest']='https://dataserv.ub.tum.de/s/m1639953/download?path=/&files=s1_asiaWest.tar.gz' + vol_dict['multi_s1_asiaWest']='59218848' + + url_dict['multi_s1_europa']='https://dataserv.ub.tum.de/s/m1639953/download?path=/&files=s1_europa.tar.gz' + vol_dict['multi_s1_europa']='121213836' + ;; + 'africa') url_dict['multi_s1_africa']='https://dataserv.ub.tum.de/s/m1639953/download?path=/&files=s1_africa.tar.gz' + vol_dict['multi_s1_africa']='60544524' + ;; + 'america') url_dict['multi_s1_america']='https://dataserv.ub.tum.de/s/m1639953/download?path=/&files=s1_america.tar.gz' + vol_dict['multi_s1_america']='67947416' + ;; + 'asiaEast') url_dict['multi_s1_asiaEast']='https://dataserv.ub.tum.de/s/m1639953/download?path=/&files=s1_asiaEast.tar.gz' + vol_dict['multi_s1_asiaEast']='70230104' + ;; + 'asiaWest') url_dict['multi_s1_asiaWest']='https://dataserv.ub.tum.de/s/m1639953/download?path=/&files=s1_asiaWest.tar.gz' + vol_dict['multi_s1_asiaWest']='59218848' + ;; + 'europa') url_dict['multi_s1_europa']='https://dataserv.ub.tum.de/s/m1639953/download?path=/&files=s1_europa.tar.gz' + vol_dict['multi_s1_europa']='121213836' + ;; + esac + + + # test split + case $region in + 'all') url_dict['multi_s1_africa_test']='https://dataserv.ub.tum.de/s/m1659251/download?path=/&files=s1_africa_test.tar.gz' + vol_dict['multi_s1_africa_test']='15668120' + + url_dict['multi_s1_america_test']='https://dataserv.ub.tum.de/s/m1659251/download?path=/&files=s1_america_test.tar.gz' + vol_dict['multi_s1_america_test']='15668160' + + url_dict['multi_s1_asiaEast_test']='https://dataserv.ub.tum.de/s/m1659251/download?path=/&files=s1_asiaEast_test.tar.gz' + vol_dict['multi_s1_asiaEast_test']='24982736' + + url_dict['multi_s1_asiaWest_test']='https://dataserv.ub.tum.de/s/m1659251/download?path=/&files=s1_asiaWest_test.tar.gz' + vol_dict['multi_s1_asiaWest_test']='9252904' + + url_dict['multi_s1_europa_test']='https://dataserv.ub.tum.de/s/m1659251/download?path=/&files=s1_europa_test.tar.gz' + vol_dict['multi_s1_europa_test']='49040432' + ;; + 'africa') url_dict['multi_s1_africa_test']='https://dataserv.ub.tum.de/s/m1659251/download?path=/&files=s1_africa_test.tar.gz' + vol_dict['multi_s1_africa_test']='15668120' + ;; + 'america') url_dict['multi_s1_america_test']='https://dataserv.ub.tum.de/s/m1659251/download?path=/&files=s1_america_test.tar.gz' + vol_dict['multi_s1_america_test']='15668160' + ;; + 'asiaEast') url_dict['multi_s1_asiaEast_test']='https://dataserv.ub.tum.de/s/m1659251/download?path=/&files=s1_asiaEast_test.tar.gz' + vol_dict['multi_s1_asiaEast_test']='24982736' + ;; + 'asiaWest') url_dict['multi_s1_asiaWest_test']='https://dataserv.ub.tum.de/s/m1659251/download?path=/&files=s1_asiaWest_test.tar.gz' + vol_dict['multi_s1_asiaWest_test']='9252904' + ;; + 'europa') url_dict['multi_s1_europa_test']='https://dataserv.ub.tum.de/s/m1659251/download?path=/&files=s1_europa_test.tar.gz' + vol_dict['multi_s1_europa_test']='49040432' + ;; + esac + fi +fi + + +# mono-temporal data (all regions) +if [ "$SEN12MSCR" = "true" ]; then + echo "Also downloading SEN12MS-CR data set." + mkdir -p $dl_extract_to'/SEN12MSCR' + url_dict['mono_s2_spring']='https://dataserv.ub.tum.de/s/m1554803/download?path=/&files=ROIs1158_spring_s2.tar.gz' + vol_dict['mono_s2_spring']='48568904' + + url_dict['mono_s2_summer']='https://dataserv.ub.tum.de/s/m1554803/download?path=/&files=ROIs1868_summer_s2.tar.gz' + vol_dict['mono_s2_summer']='56425520' + + url_dict['mono_s2_fall']='https://dataserv.ub.tum.de/s/m1554803/download?path=/&files=ROIs1970_fall_s2.tar.gz' + vol_dict['mono_s2_fall']='68291864' + + url_dict['mono_s2_winter']='https://dataserv.ub.tum.de/s/m1554803/download?path=/&files=ROIs2017_winter_s2.tar.gz' + vol_dict['mono_s2_winter']='30580552' + + url_dict['mono_s2_cloudy_spring']='https://dataserv.ub.tum.de/s/m1554803/download?path=/&files=ROIs1158_spring_s2_cloudy.tar.gz' + vol_dict['mono_s2_cloudy_spring']='48569368' + + url_dict['mono_s2_cloudy_summer']='https://dataserv.ub.tum.de/s/m1554803/download?path=/&files=ROIs1868_summer_s2_cloudy.tar.gz' + vol_dict['mono_s2_cloudy_summer']='56426004' + + url_dict['mono_s2_cloudy_fall']='https://dataserv.ub.tum.de/s/m1554803/download?path=/&files=ROIs1970_fall_s2_cloudy.tar.gz' + vol_dict['mono_s2_cloudy_fall']='68292448' + + url_dict['mono_s2_cloudy_winter']='https://dataserv.ub.tum.de/s/m1554803/download?path=/&files=ROIs2017_winter_s2_cloudy.tar.gz' + vol_dict['mono_s2_cloudy_winter']='30580812' + + # S1 data of SEN12MS-CR + if [ "$S1" = "true" ]; then + echo "Also downloading associated S1 data." + url_dict['mono_s1_spring']='https://dataserv.ub.tum.de/s/m1554803/download?path=/&files=ROIs1158_spring_s1.tar.gz' + vol_dict['mono_s1_spring']='15026120' + + url_dict['mono_s1_summer']='https://dataserv.ub.tum.de/s/m1554803/download?path=/&files=ROIs1868_summer_s1.tar.gz' + vol_dict['mono_s1_summer']='17456784' + + url_dict['mono_s1_fall']='https://dataserv.ub.tum.de/s/m1554803/download?path=/&files=ROIs1970_fall_s1.tar.gz' + vol_dict['mono_s1_fall']='21127832' + + url_dict['mono_s1_winter']='https://dataserv.ub.tum.de/s/m1554803/download?path=/&files=ROIs2017_winter_s1.tar.gz' + vol_dict['mono_s1_winter']='9460956' + fi +fi + +req=0 +# integrate file size across archives +for key in "${!vol_dict[@]}"; do + # for each archive: sum up + curr=${vol_dict[$key]} + req=$((req+curr)) +done + +echo +echo +# df -h $dl_extract_to +avail=$(df $dl_extract_to | awk 'NR==2 { print $4 }') +if (( avail < req )); then + echo "Not enough space (512-byte disk sectors) on path "$dl_extract_to". Available "$avail". Required "$req #>&2 + exit 1 +else + echo "Consuming "$req" of "$avail" (512-byte disk sectors) on path "$dl_extract_to +fi +echo +echo + +# download each archive individually, then extract individually + +# fetch the actual data +for key in "${!url_dict[@]}"; do + url=${url_dict[$key]} + filename=$(basename "$url") + filename=${filename:7} + # download + wget --no-check-certificate -c -O $dl_extract_to'/'$filename ${url_dict[$key]} + # unzip and delete archive + tar --extract --file $dl_extract_to'/'$filename -C $dl_extract_to + rm $dl_extract_to'/'$filename +done + +# move the extracted data to its respective place (this may take a while, because we use rsync rather than mv) +echo "Moving data in place, please don't stop this process." +for key in "${!url_dict[@]}"; do + url=${url_dict[$key]} + filename=$(basename "$url") + filename=${filename:7:-7} # remove base URL and trailing *.tar.gz + if [[ ${url_dict[$key]} == *"m1554803"* ]]; then + # move to SEN12MSCR directory + mv $dl_extract_to'/'$filename $dl_extract_to'/SEN12MSCR' + elif [[ ${url_dict[$key]} == *"m1639953"* ]]; then + # move train ROI to SEN12MSCRTS directory + no_prefix_filename=${filename:3} + rsync -a -remove-source-files $dl_extract_to'/'$no_prefix_filename/* $dl_extract_to'/SEN12MSCRTS' 2>/dev/null + rm -rf $dl_extract_to'/'$no_prefix_filename + else + # move test ROI to SEN12MSCRTS directory + rsync -a -remove-source-files $dl_extract_to'/'$filename/* $dl_extract_to'/SEN12MSCRTS' + rm -rf $dl_extract_to'/'$filename + fi +done + +echo +echo "Completed downloading, extracting and moving data! Enjoy :)" diff --git a/UnCRtainTS/util/hdf5converter/script_tif2hdf5.sh b/UnCRtainTS/util/hdf5converter/script_tif2hdf5.sh new file mode 100644 index 0000000000000000000000000000000000000000..feb5da4d84c59f77ad99edc96344acedca43921f --- /dev/null +++ b/UnCRtainTS/util/hdf5converter/script_tif2hdf5.sh @@ -0,0 +1,15 @@ +#!/usr/bin/env bash + +# code kindly provided by Corinne Stucker + +#python tif2hdf5.py /scratch2/Data/SEN12MS-CR-TS/ val europa /scratch2/Data/SEN12MS-CR-TS_hdf5/all_ROIs/ +#python tif2hdf5.py /scratch2/Data/SEN12MS-CR-TS/ val america /scratch2/Data/SEN12MS-CR-TS_hdf5/all_ROIs/ +#python tif2hdf5.py /scratch2/Data/SEN12MS-CR-TS/ val africa /scratch2/Data/SEN12MS-CR-TS_hdf5/all_ROIs/ + +python tif2hdf5.py /scratch2/Data/SEN12MS-CR-TS/ val asiaWest /scratch2/Data/SEN12MS-CR-TS_hdf5/all_ROIs/ +python tif2hdf5.py /scratch2/Data/SEN12MS-CR-TS/ val asiaEast /scratch2/Data/SEN12MS-CR-TS_hdf5/all_ROIs/ +#python tif2hdf5.py /scratch2/Data/SEN12MS-CR-TS_testSplit/ test europa /scratch2/Data/SEN12MS-CR-TS_hdf5/all_ROIs/ +#python tif2hdf5.py /scratch2/Data/SEN12MS-CR-TS_testSplit/ test america /scratch2/Data/SEN12MS-CR-TS_hdf5/all_ROIs/ +#python tif2hdf5.py /scratch2/Data/SEN12MS-CR-TS_testSplit/ test africa /scratch2/Data/SEN12MS-CR-TS_hdf5/all_ROIs/ +#python tif2hdf5.py /scratch2/Data/SEN12MS-CR-TS_testSplit/ test asiaWest /scratch2/Data/SEN12MS-CR-TS_hdf5/all_ROIs/ +#python tif2hdf5.py /scratch2/Data/SEN12MS-CR-TS_testSplit/ test asiaEast /scratch2/Data/SEN12MS-CR-TS_hdf5/all_ROIs/ diff --git a/UnCRtainTS/util/hdf5converter/sen12mscrts_to_hdf5.py b/UnCRtainTS/util/hdf5converter/sen12mscrts_to_hdf5.py new file mode 100644 index 0000000000000000000000000000000000000000..b1585e89070b9094f31a28aec5478feb43c994ca --- /dev/null +++ b/UnCRtainTS/util/hdf5converter/sen12mscrts_to_hdf5.py @@ -0,0 +1,211 @@ +# code kindly provided by Corinne Stucker + +from natsort import natsorted +import numpy as np +import os +import rasterio +from tqdm import tqdm +from scipy.ndimage import gaussian_filter +from s2cloudless import S2PixelCloudDetector + +from data.dataLoader import SEN12MSCRTS + +""" SEN12MSCRTS data loader class, used to load the data in the original format and prepare the data for hdf5 export + + IN: + root: str, path to your copy of the SEN12MS-CR-TS data set + split: str, in [all | train | val | test] + region: str, [all | africa | america | asiaEast | asiaWest | europa] + cloud_masks: str, type of cloud mask detector to run on optical data, in [None | cloud_cloudshadow_mask | s2cloudless_map | s2cloudless_mask] + + OUT: + data_loader: SEN12MSCRTS instance, implements an iterator that can be traversed via __getitem__(pdx), + which returns the pdx-th dictionary of patch-samples (whose structure depends on sample_type) +""" + + +class SEN12MSCRTS_to_hdf5(SEN12MSCRTS): + def __init__(self, root, split="all", region='all', cloud_masks='s2cloudless_mask', modalities=["S1", "S2"]): + + self.root_dir = root # set root directory which contains all ROI + self.region = region # region according to which the ROI are selected + self.ROI = {'ROIs1158': ['106'], + 'ROIs1868': ['17', '36', '56', '73', '85', '100', '114', '119', '121', '126', '127', '139', '142', + '143'], + 'ROIs1970': ['20', '21', '35', '40', '57', '65', '71', '82', '83', '91', '112', '116', '119', '128', + '132', '133', '135', '139', '142', '144', '149'], + 'ROIs2017': ['8', '22', '25', '32', '49', '61', '63', '69', '75', '103', '108', '115', '116', '117', + '130', '140', '146']} + + # define splits conform with SEN12MS-CR + self.splits = {} + if self.region == 'all': + all_ROI = [os.path.join(key, val) for key, vals in self.ROI.items() for val in vals] + self.splits['test'] = [os.path.join('ROIs1868', '119'), os.path.join('ROIs1970', '139'), + os.path.join('ROIs2017', '108'), os.path.join('ROIs2017', '63'), + os.path.join('ROIs1158', '106'), os.path.join('ROIs1868', '73'), + os.path.join('ROIs2017', '32'), + os.path.join('ROIs1868', '100'), os.path.join('ROIs1970', '132'), + os.path.join('ROIs2017', '103'), os.path.join('ROIs1868', '142'), + os.path.join('ROIs1970', '20'), + os.path.join('ROIs2017', '140')] # official test split, across continents + self.splits['val'] = [os.path.join('ROIs2017', '22'), os.path.join('ROIs1970', '65'), + os.path.join('ROIs2017', '117'), os.path.join('ROIs1868', '127'), + os.path.join('ROIs1868', '17')] # insert your favorite validation split here + self.splits['train'] = [roi for roi in all_ROI if roi not in self.splits['val'] and roi not in self.splits[ + 'test']] # all remaining ROI are used for training + elif self.region == 'africa': + self.splits['test'] = [os.path.join('ROIs2017', '32'), os.path.join('ROIs2017', '140')] + self.splits['val'] = [os.path.join('ROIs2017', '22')] + self.splits['train'] = [os.path.join('ROIs1970', '21'), os.path.join('ROIs1970', '35'), + os.path.join('ROIs1970', '40'), + os.path.join('ROIs2017', '8'), os.path.join('ROIs2017', '61'), + os.path.join('ROIs2017', '75')] + elif self.region == 'america': + self.splits['test'] = [os.path.join('ROIs1158', '106'), os.path.join('ROIs1970', '132')] + self.splits['val'] = [os.path.join('ROIs1970', '65')] + self.splits['train'] = [os.path.join('ROIs1868', '36'), os.path.join('ROIs1868', '85'), + os.path.join('ROIs1970', '82'), os.path.join('ROIs1970', '142'), + os.path.join('ROIs2017', '49'), os.path.join('ROIs2017', '116')] + elif self.region == 'asiaEast': + self.splits['test'] = [os.path.join('ROIs1868', '73'), os.path.join('ROIs1868', '119'), + os.path.join('ROIs1970', '139')] + self.splits['val'] = [os.path.join('ROIs2017', '117')] + self.splits['train'] = [os.path.join('ROIs1868', '114'), os.path.join('ROIs1868', '126'), + os.path.join('ROIs1868', '143'), + os.path.join('ROIs1970', '116'), os.path.join('ROIs1970', '135'), + os.path.join('ROIs2017', '25')] + elif self.region == 'asiaWest': + self.splits['test'] = [os.path.join('ROIs1868', '100')] + self.splits['val'] = [os.path.join('ROIs1868', '127')] + self.splits['train'] = [os.path.join('ROIs1970', '57'), os.path.join('ROIs1970', '83'), + os.path.join('ROIs1970', '112'), + os.path.join('ROIs2017', '69'), os.path.join('ROIs1970', '115'), + os.path.join('ROIs1970', '130')] + elif self.region == 'europa': + self.splits['test'] = [os.path.join('ROIs2017', '63'), os.path.join('ROIs2017', '103'), + os.path.join('ROIs2017', '108'), + os.path.join('ROIs1868', '142'), os.path.join('ROIs1970', '20')] + self.splits['val'] = [os.path.join('ROIs1868', '17')] + self.splits['train'] = [os.path.join('ROIs1868', '56'), os.path.join('ROIs1868', '121'), + os.path.join('ROIs1868', '139'), + os.path.join('ROIs1970', '71'), os.path.join('ROIs1970', '91'), + os.path.join('ROIs1970', '119'), + os.path.join('ROIs1970', '128'), os.path.join('ROIs1970', '133'), + os.path.join('ROIs1970', '144'), + os.path.join('ROIs1970', '149'), + os.path.join('ROIs2017', '146')] + else: + raise NotImplementedError + + self.splits["all"] = self.splits["train"] + self.splits["test"] + self.splits["val"] + self.split = split + + assert split in ['all', 'train', 'val', + 'test'], "Input dataset must be either assigned as all, train, test, or val!" + assert cloud_masks in [None, 'cloud_cloudshadow_mask', 's2cloudless_map', + 's2cloudless_mask'], "Unknown cloud mask type!" + + self.modalities = modalities + self.time_points = range(30) + self.cloud_masks = cloud_masks # e.g. 'cloud_cloudshadow_mask', 's2cloudless_map', 's2cloudless_mask' + + if self.cloud_masks in ['s2cloudless_map', 's2cloudless_mask']: + self.cloud_detector = S2PixelCloudDetector(threshold=0.4, all_bands=True, average_over=4, dilation_size=2) + + self.paths = self.get_paths() + self.n_samples = len(self.paths) + + # raise a warning that no data has been found + if not self.n_samples: self.throw_warn() + + def get_paths(self): # assuming for the same ROI+num, the patch numbers are the same + print(f'\nProcessing paths for {self.split} split of region {self.region}') + + paths = [] + for roi_dir, rois in self.ROI.items(): + for roi in tqdm(rois): + roi_path = os.path.join(self.root_dir, roi_dir, roi) + # skip non-existent ROI or ROI not part of the current data split + if not os.path.isdir(roi_path) or os.path.join(roi_dir, roi) not in self.splits[self.split]: continue + path_s1_t, path_s2_t = [], [] + for tdx in self.time_points: + if 'S1' in self.modalities: + path_s1_complete = os.path.join(roi_path, 'S1', str(tdx)) + path_s1 = os.path.join(roi_dir, roi, 'S1', str(tdx)) + s1_t = natsorted([os.path.join(path_s1, f) for f in os.listdir(path_s1_complete) if + (os.path.isfile(os.path.join(path_s1_complete, f)) and ".tif" in f)]) + if 'S2' in self.modalities: + path_s2_complete = os.path.join(roi_path, 'S2', str(tdx)) + path_s2 = os.path.join(roi_dir, roi, 'S2', str(tdx)) + s2_t = natsorted([os.path.join(path_s2, f) for f in os.listdir(path_s2_complete) if + (os.path.isfile(os.path.join(path_s2_complete, f)) and ".tif" in f)]) + + if 'S1' in self.modalities and 'S2' in self.modalities: + # same number of patches + assert len(s1_t) == len(s2_t) + + # sort via file names according to patch number and store + if 'S1' in self.modalities: + path_s1_t.append(s1_t) + if 'S2' in self.modalities: + path_s2_t.append(s2_t) + + # for each patch of the ROI, collect its time points and make this one sample + for pdx in range(len(path_s1_t[0])): + sample = dict() + if 'S1' in self.modalities: + sample['S1'] = [path_s1_t[tdx][pdx] for tdx in self.time_points] + if 'S2' in self.modalities: + sample['S2'] = [path_s2_t[tdx][pdx] for tdx in self.time_points] + + paths.append(sample) + + return paths + + def get_cloud_mask(self, img, mask_type): + if mask_type == 'cloud_cloudshadow_mask': + threshold = 0.2 # set to e.g. 0.2 or 0.4 + mask = self.get_cloud_cloudshadow_mask(np.clip(img, 0, 10000), threshold) + elif mask_type == 's2cloudless_map': + threshold = 0.5 + mask = self.cloud_detector.get_cloud_probability_maps(np.moveaxis(np.clip(img, 0, 10000)/10000, 0, -1)[None, ...])[0, ...] + mask[mask < threshold] = 0 + mask = gaussian_filter(mask, sigma=2).astype(np.float32) + elif mask_type == 's2cloudless_mask': + mask = self.cloud_detector.get_cloud_masks(np.moveaxis(np.clip(img, 0, 10000)/10000, 0, -1)[None, ...])[0, ...] + elif mask_type == 's2cloud_prob': + mask = self.cloud_detector.get_cloud_probability_maps(np.moveaxis(np.clip(img, 0, 10000) / 10000, 0, -1)[None, ...])[0, ...] + + return mask + + def __getitem__(self, pdx): # get the time series of one patch + + sample = dict() + + if 'S1' in self.modalities: + s1 = [self.read_img(os.path.join(self.root_dir, img)) for img in self.paths[pdx]['S1']] + s1_dates = [img.split('/')[-1].split('_')[5] for img in self.paths[pdx]['S1']] + sample['S1'] = s1 + sample['S1_dates'] = s1_dates + sample['S1_paths'] = self.paths[pdx]['S1'] + + if 'S2' in self.modalities: + s2 = [self.read_img(os.path.join(self.root_dir, img)) for img in self.paths[pdx]['S2']] + s2_dates = [img.split('/')[-1].split('_')[5] for img in self.paths[pdx]['S2']] + + cloud_prob = [self.get_cloud_mask(img, 's2cloud_prob') for img in s2] + cloud_mask = [self.get_cloud_mask(img, 's2cloudless_mask') for img in s2] + + sample['S2'] = s2 + sample['S2_dates'] = s2_dates + sample['S2_paths'] = self.paths[pdx]['S2'] + sample['cloud_prob'] = cloud_prob + + sample['cloud_mask'] = cloud_mask + + return sample + + def __len__(self): + # length of generated list + return self.n_samples diff --git a/UnCRtainTS/util/pre_compute_data_samples.py b/UnCRtainTS/util/pre_compute_data_samples.py new file mode 100644 index 0000000000000000000000000000000000000000..c4b91b75fcebe5f69a3deb40753efed5ecb0e937 --- /dev/null +++ b/UnCRtainTS/util/pre_compute_data_samples.py @@ -0,0 +1,128 @@ +""" + Python script to pre-compute cloud coverage statistics on the data of SEN12MS-CR-TS. + The data loader performs online sampling of input and target patches depending on its flags + (e.g.: split, region, n_input_samples, min_cov, max_cov, ) and the patches' calculated cloud coverage. + If using sampler='random', patches can also vary across epochs to act as data augmentation mechanism. + + However, online computing of cloud masks can slow down data loading. A solution is to pre-compute + cloud coverage an relief the dataloader from re-computing each sample, which is what this script offers. + Currently, pre-calculated statistics are exported in an *.npy file, a collection of which is readily + available for download via https://syncandshare.lrz.de/getlink/fiHhwCqr7ch3X39XoGYaUGM8/splits + + Pre-computed statistics can be imported via the dataloader's "import_data_path" argument. +""" + +import os +import sys +import time +import random +import numpy as np +from tqdm import tqdm + +import resource +rlimit = resource.getrlimit(resource.RLIMIT_NOFILE) +# see: https://docs.python.org/3/library/resource.html#resource.RLIM_INFINITY +resource.setrlimit(resource.RLIMIT_NOFILE, (int(1024*1e3), rlimit[1])) + +import torch +dirname = os.path.dirname(os.path.abspath(__file__)) +sys.path.append(os.path.dirname(dirname)) +from data.dataLoader import SEN12MSCRTS + +# fix all RNG seeds +seed = 1 +np.random.seed(seed) +torch.manual_seed(seed) + +def seed_worker(worker_id): + worker_seed = torch.initial_seed() % 2**32 + np.random.seed(worker_seed) + random.seed(worker_seed) + +g = torch.Generator() +g.manual_seed(seed) + + +pathify = lambda path_list: [os.path.join(*path[0].split('/')[-6:]) for path in path_list] + +if __name__ == '__main__': + # main parameters for instantiating SEN12MS-CR-TS + root = '/home/data/SEN12MSCRTS' # path to your copy of SEN12MS-CR-TS + split = 'test' # ROI to sample from, belonging to splits [all | train | val | test] + input_t = 3 # number of input time points to sample (irrelevant if choosing sample_type='generic') + region = 'all' # choose the region of data input. [all | africa | america | asiaEast | asiaWest | europa] + sample_type = 'generic' # type of samples returned [cloudy_cloudfree | generic] + import_data_path = None # path to importing the suppl. file specifying what time points to load for input and output, e.g. os.path.join(os.getcwd(), 'util', '3_test_s2cloudless_mask.npy') + export_data_path = os.path.join(dirname, 'precomputed') # e.g. ...'/3_all_train_vary_s2cloudless_mask.npy' + vary = 'random' if split!='test' else 'fixed' # whether to vary samples across epoch or not + n_epochs = 1 if vary=='fixed' or sample_type=='generic' else 30 # if not varying dates across epochs, then a single epoch is sufficient + max_samples = int(1e9) + + shuffle = False + if export_data_path is not None: # if exporting data indices to file then need to disable DataLoader shuffling, else pdx are not sorted (they may still be shuffled when importing) + shuffle = False # ---for importing, shuffling may change the order from that of the exported file (which may or may not be desired) + + sen12mscrts = SEN12MSCRTS(root, split=split, sample_type=sample_type, n_input_samples=input_t, region=region, sampler=vary, import_data_path=import_data_path) + # instantiate dataloader, note: worker_init_fn is needed to get reproducible random samples across runs if vary_samples=True + # note: if using 'export_data_path' then keep batch_size at 1 (unless moving data writing out of dataloader) + # and shuffle=False (processes patches in order, but later imports can still shuffle this) + dataloader = torch.utils.data.DataLoader(sen12mscrts, batch_size=1, shuffle=shuffle, worker_init_fn=seed_worker, generator=g, num_workers=0) + + if export_data_path is not None: + data_pairs = {} # collect pre-computed dates in a dict to be exported + epoch_count = 0 # count, for loading time points that vary across epochs + collect_var = [] # collect variance across S2 intensities + + # iterate over data to pre-compute indices for e.g. training or testing + start_timer = time.time() + for epoch in range(1, n_epochs + 1): + print(f'\nCurating indices for {epoch}. epoch.') + for pdx, patch in enumerate(tqdm(dataloader)): + # stop sampling when sample count is exceeded + if pdx>=max_samples: break + + if sample_type == 'generic': + # collect variances in all samples' S2 intensities, finally compute grand average variance + collect_var.append(torch.stack(patch['S2']).var()) + + if export_data_path is not None: + if sample_type == 'cloudy_cloudfree': + # compute epoch-sensitive index, such that exported dates can differ across epochs + adj_pdx = epoch_count*dataloader.dataset.__len__() + pdx + # performs repeated writing to file, only use this for processes dedicated for exporting + # and if so, only use a single thread of workers (--num_threads 1), this ain't thread-safe + data_pairs[adj_pdx] = {'input': patch['input']['idx'], 'target': patch['target']['idx'], + 'coverage': {'input': patch['input']['coverage'], + 'output': patch['output']['coverage']}, + 'paths': {'input': {'S1': pathify(patch['input']['S1 path']), + 'S2': pathify(patch['input']['S2 path'])}, + 'output': {'S1': pathify(patch['target']['S1 path']), + 'S2': pathify(patch['target']['S2 path'])}}} + elif sample_type == 'generic': + # performs repeated writing to file, only use this for processes dedicated for exporting + # and if so, only use a single thread of workers (--num_threads 1), this ain't thread-safe + data_pairs[pdx] = {'coverage': patch['coverage'], + 'paths': {'S1': pathify(patch['S1 path']), + 'S2': pathify(patch['S2 path'])}} + if sample_type == 'generic': + # export collected dates + # eiter do this here after each epoch or after all epochs + if export_data_path is not None: + ds = dataloader.dataset + if os.path.isdir(export_data_path): + export_here = os.path.join(export_data_path, f'{sample_type}_{input_t}_{split}_{region}_{ds.cloud_masks}.npy') + else: + export_here = export_data_path + np.save(export_here, data_pairs) + print(f'\nEpoch {epoch_count+1}/{n_epochs}: Exported pre-computed dates to {export_here}') + + # bookkeeping at the end of epoch + epoch_count += 1 + + print(f'The grand average variance of S2 samples in the {split} split is: {torch.mean(torch.tensor(collect_var))}') + + if export_data_path is not None: print('Completed exporting data.') + + # benchmark speed of dataloader when (not) using 'import_data_path' flag + elapsed = time.time() - start_timer + print(f'Elapsed time is {elapsed}') \ No newline at end of file diff --git a/UnCRtainTS/util/pytorch_ssim/__init__.py b/UnCRtainTS/util/pytorch_ssim/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..738e8038942e85592aa7c88070fc6e83e9f9c776 --- /dev/null +++ b/UnCRtainTS/util/pytorch_ssim/__init__.py @@ -0,0 +1,73 @@ +import torch +import torch.nn.functional as F +from torch.autograd import Variable +import numpy as np +from math import exp + +def gaussian(window_size, sigma): + gauss = torch.Tensor([exp(-(x - window_size//2)**2/float(2*sigma**2)) for x in range(window_size)]) + return gauss/gauss.sum() + +def create_window(window_size, channel): + _1D_window = gaussian(window_size, 1.5).unsqueeze(1) + _2D_window = _1D_window.mm(_1D_window.t()).float().unsqueeze(0).unsqueeze(0) + window = Variable(_2D_window.expand(channel, 1, window_size, window_size).contiguous()) + return window + +def _ssim(img1, img2, window, window_size, channel, size_average = True): + mu1 = F.conv2d(img1, window, padding = window_size//2, groups = channel) + mu2 = F.conv2d(img2, window, padding = window_size//2, groups = channel) + + mu1_sq = mu1.pow(2) + mu2_sq = mu2.pow(2) + mu1_mu2 = mu1*mu2 + + sigma1_sq = F.conv2d(img1*img1, window, padding = window_size//2, groups = channel) - mu1_sq + sigma2_sq = F.conv2d(img2*img2, window, padding = window_size//2, groups = channel) - mu2_sq + sigma12 = F.conv2d(img1*img2, window, padding = window_size//2, groups = channel) - mu1_mu2 + + C1 = 0.01**2 + C2 = 0.03**2 + + ssim_map = ((2*mu1_mu2 + C1)*(2*sigma12 + C2))/((mu1_sq + mu2_sq + C1)*(sigma1_sq + sigma2_sq + C2)) + + if size_average: + return ssim_map.mean() + else: + return ssim_map.mean(1).mean(1).mean(1) + +class SSIM(torch.nn.Module): + def __init__(self, window_size = 11, size_average = True): + super(SSIM, self).__init__() + self.window_size = window_size + self.size_average = size_average + self.channel = 1 + self.window = create_window(window_size, self.channel) + + def forward(self, img1, img2): + (_, channel, _, _) = img1.size() + + if channel == self.channel and self.window.data.type() == img1.data.type(): + window = self.window + else: + window = create_window(self.window_size, channel) + + if img1.is_cuda: + window = window.cuda(img1.get_device()) + window = window.type_as(img1) + + self.window = window + self.channel = channel + + + return _ssim(img1, img2, window, self.window_size, channel, self.size_average) + +def ssim(img1, img2, window_size = 11, size_average = True): + (_, channel, _, _) = img1.size() + window = create_window(window_size, channel) + + if img1.is_cuda: + window = window.cuda(img1.get_device()) + window = window.type_as(img1) + + return _ssim(img1, img2, window, window_size, channel, size_average) diff --git a/UnCRtainTS/util/pytorch_ssim/__pycache__/__init__.cpython-311.pyc b/UnCRtainTS/util/pytorch_ssim/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..17c1e135bd2f1fc2c494bb7901a5afa35777a86c Binary files /dev/null and b/UnCRtainTS/util/pytorch_ssim/__pycache__/__init__.cpython-311.pyc differ diff --git a/UnCRtainTS/util/utils.py b/UnCRtainTS/util/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..be82042d39973b83d984a5fc75765c4c60a3671e --- /dev/null +++ b/UnCRtainTS/util/utils.py @@ -0,0 +1,116 @@ +"""This module contains simple helper functions """ +from __future__ import print_function +import torch +import numpy as np +from PIL import Image +import os + + +def tensor2im(input_image, method, imtype=np.uint8): + """"Converts a Tensor array into a numpy image array. + + Parameters: + input_image (tensor) -- the input image tensor array + imtype (type) -- the desired type of the converted numpy array + """ + if not isinstance(input_image, np.ndarray): + if isinstance(input_image, torch.Tensor): # get the data from a variable + image_tensor = input_image.data + else: + return input_image + image_numpy = image_tensor[0].cpu().float().numpy() # convert it into a numpy array + # no need to do anything if image_numpy is 3-dimensiona already but for the other dimensions ... + + if image_numpy.shape[0] == 1: # grayscale to RGB + image_numpy = np.tile(image_numpy, (3, 1, 1)) # triple channel + image_numpy = (np.transpose(image_numpy, (1, 2, 0))) * 255.0 + + if image_numpy.shape[0] == 13 or image_numpy.shape[0] == 4: # 13 bands multispectral (or 4 bands NIR) to RGB + # RGB bands are [3, 2, 1] + image_numpy = image_numpy[[3, 2, 1], ...] + + # method is either 'resnet' (if opt.alter_initial_mode) or 'default' + if method == 'default': # re-normalize from [-1,+1] to [0,+1] + image_numpy = (np.transpose(image_numpy, (1, 2, 0)) + 1) / 2.0 * 255.0 + elif method == 'resnet': # re-normalize from [0, 5] to [0,+1] + image_numpy = (np.transpose(image_numpy, (1, 2, 0))) / 5.0 * 255.0 + + if image_numpy.shape[0] == 2: # (VV,VH) SAR to RGB (just taking VV band) + image_numpy = np.tile(image_numpy[[0]], (3, 1, 1)) + if method == 'default': # re-normalize from [-1,+1] to [0,+1] + image_numpy = (np.transpose(image_numpy, (1, 2, 0)) + 1) / 2.0 * 255.0 + elif method == 'resnet': # re-normalize from [0, 2] to [0,+1] + image_numpy = (np.transpose(image_numpy, (1, 2, 0))) / 2.0 * 255.0 + # post-processing: tranpose and scaling + else: # if it is a numpy array, do nothing + image_numpy = input_image + return image_numpy.astype(imtype) + + +def diagnose_network(net, name='network'): + """Calculate and print the mean of average absolute(gradients) + + Parameters: + net (torch network) -- Torch network + name (str) -- the name of the network + """ + mean = 0.0 + count = 0 + for param in net.parameters(): + if param.grad is not None: + mean += torch.mean(torch.abs(param.grad.data)) + count += 1 + if count > 0: + mean = mean / count + print(name) + print(mean) + + +def save_image(image_numpy, image_path): + """Save a numpy image to the disk + + Parameters: + image_numpy (numpy array) -- input numpy array + image_path (str) -- the path of the image + """ + image_pil = Image.fromarray(image_numpy) + image_pil.save(image_path) + + +def print_numpy(x, val=True, shp=False): + """Print the mean, min, max, median, std, and size of a numpy array + + Parameters: + val (bool) -- if print the values of the numpy array + shp (bool) -- if print the shape of the numpy array + """ + x = x.astype(np.float64) + if shp: + print('shape,', x.shape) + if val: + x = x.flatten() + print('mean = %3.3f, min = %3.3f, max = %3.3f, median = %3.3f, std=%3.3f' % ( + np.mean(x), np.min(x), np.max(x), np.median(x), np.std(x))) + + +def mkdirs(paths): + """create empty directories if they don't exist + + Parameters: + paths (str list) -- a list of directory paths + """ + if isinstance(paths, list) and not isinstance(paths, str): + for path in paths: + mkdir(path) + else: + mkdir(paths) + + +def mkdir(path): + """create a single empty directory if it didn't exist + + Parameters: + path (str) -- a single directory path + """ + if not os.path.exists(path): + os.makedirs(path) \ No newline at end of file