diff --git a/.env.example b/.env.example
new file mode 100644
index 0000000000000000000000000000000000000000..a790e320464ebc778ca07f5bcd826a9c8412ed0e
--- /dev/null
+++ b/.env.example
@@ -0,0 +1,6 @@
+# Example of a file for storing private and user-specific environment variables, like keys or system paths
+# Rename it to ".env" (excluded from version control by default)
+# .env is loaded by train.py automatically
+# Hydra allows you to reference variables from .env in .yaml configs with the syntax: ${oc.env:MY_VAR}
+
+MY_VAR="/home/user/my/system/path"
diff --git a/.gitignore b/.gitignore
new file mode 100644
index 0000000000000000000000000000000000000000..04a06484441a3d09afd793ef8a7107931de8e06f
--- /dev/null
+++ b/.gitignore
@@ -0,0 +1,154 @@
+# Byte-compiled / optimized / DLL files
+__pycache__/
+*.py[cod]
+*$py.class
+
+# C extensions
+*.so
+
+# Distribution / packaging
+.Python
+build/
+develop-eggs/
+dist/
+downloads/
+eggs/
+.eggs/
+lib/
+lib64/
+parts/
+sdist/
+var/
+wheels/
+pip-wheel-metadata/
+share/python-wheels/
+*.egg-info/
+.installed.cfg
+*.egg
+MANIFEST
+
+# PyInstaller
+# Usually these files are written by a python script from a template
+# before PyInstaller builds the exe, so as to inject date/other infos into it.
+*.manifest
+*.spec
+
+# Installer logs
+pip-log.txt
+pip-delete-this-directory.txt
+
+# Unit test / coverage reports
+htmlcov/
+.tox/
+.nox/
+.coverage
+.coverage.*
+.cache
+nosetests.xml
+coverage.xml
+*.cover
+*.py,cover
+.hypothesis/
+.pytest_cache/
+
+# Translations
+*.mo
+*.pot
+
+# Django stuff:
+*.log
+local_settings.py
+db.sqlite3
+db.sqlite3-journal
+
+# Flask stuff:
+instance/
+.webassets-cache
+
+# Scrapy stuff:
+.scrapy
+
+# Sphinx documentation
+docs/_build/
+
+# PyBuilder
+target/
+
+# Jupyter Notebook
+.ipynb_checkpoints
+
+# IPython
+profile_default/
+ipython_config.py
+
+# pyenv
+.python-version
+
+# pipenv
+# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
+# However, in case of collaboration, if having platform-specific dependencies or dependencies
+# having no cross-platform support, pipenv may install dependencies that don't work, or not
+# install all needed dependencies.
+#Pipfile.lock
+
+# PEP 582; used by e.g. github.com/David-OConnor/pyflow
+__pypackages__/
+
+# Celery stuff
+celerybeat-schedule
+celerybeat.pid
+
+# SageMath parsed files
+*.sage.py
+
+# Environments
+.venv
+env/
+venv/
+ENV/
+env.bak/
+venv.bak/
+
+# Spyder project settings
+.spyderproject
+.spyproject
+
+# Rope project settings
+.ropeproject
+
+# mkdocs documentation
+/site
+
+# mypy
+.mypy_cache/
+.dmypy.json
+dmypy.json
+
+# Pyre type checker
+.pyre/
+
+### VisualStudioCode
+.vscode/*
+!.vscode/settings.json
+!.vscode/tasks.json
+!.vscode/launch.json
+!.vscode/extensions.json
+*.code-workspace
+**/.vscode
+
+# JetBrains
+.idea/
+
+# Data & Models
+*.h5
+*.tar
+*.tar.gz
+
+# Lightning-Hydra-Template
+configs/local/default.yaml
+/data/
+/logs/
+.env
+
+# Aim logging
+.aim
diff --git a/.netrc b/.netrc
new file mode 100644
index 0000000000000000000000000000000000000000..b48266ac5f9c056fc0d572e579d6a53466300350
--- /dev/null
+++ b/.netrc
@@ -0,0 +1,3 @@
+machine api.wandb.ai
+ login user
+ password 76211aa17d75da9ddab7b8cba5743454194fe1d5
diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..ee45ce1946f075adb092b2d574abcbdb96169984
--- /dev/null
+++ b/.pre-commit-config.yaml
@@ -0,0 +1,147 @@
+default_language_version:
+ python: python3
+
+repos:
+ - repo: https://github.com/pre-commit/pre-commit-hooks
+ rev: v4.4.0
+ hooks:
+ # list of supported hooks: https://pre-commit.com/hooks.html
+ - id: trailing-whitespace
+ - id: end-of-file-fixer
+ - id: check-docstring-first
+ - id: check-yaml
+ - id: debug-statements
+ - id: detect-private-key
+ - id: check-executables-have-shebangs
+ - id: check-toml
+ - id: check-case-conflict
+ - id: check-added-large-files
+
+ # python code formatting
+ - repo: https://github.com/psf/black
+ rev: 23.1.0
+ hooks:
+ - id: black
+ args: [--line-length, "99"]
+
+ # python import sorting
+ - repo: https://github.com/PyCQA/isort
+ rev: 5.12.0
+ hooks:
+ - id: isort
+ args: ["--profile", "black", "--filter-files"]
+
+ # python upgrading syntax to newer version
+ - repo: https://github.com/asottile/pyupgrade
+ rev: v3.3.1
+ hooks:
+ - id: pyupgrade
+ args: [--py38-plus]
+
+ # python docstring formatting
+ - repo: https://github.com/myint/docformatter
+ rev: v1.7.4
+ hooks:
+ - id: docformatter
+ args:
+ [
+ --in-place,
+ --wrap-summaries=99,
+ --wrap-descriptions=99,
+ --style=sphinx,
+ --black,
+ ]
+
+ # python docstring coverage checking
+ - repo: https://github.com/econchick/interrogate
+ rev: 1.5.0 # or master if you're bold
+ hooks:
+ - id: interrogate
+ args:
+ [
+ --verbose,
+ --fail-under=80,
+ --ignore-init-module,
+ --ignore-init-method,
+ --ignore-module,
+ --ignore-nested-functions,
+ -vv,
+ ]
+
+ # python check (PEP8), programming errors and code complexity
+ - repo: https://github.com/PyCQA/flake8
+ rev: 6.0.0
+ hooks:
+ - id: flake8
+ args:
+ [
+ "--extend-ignore",
+ "E203,E402,E501,F401,F841,RST2,RST301",
+ "--exclude",
+ "logs/*,data/*",
+ ]
+ additional_dependencies: [flake8-rst-docstrings==0.3.0]
+
+ # python security linter
+ - repo: https://github.com/PyCQA/bandit
+ rev: "1.7.5"
+ hooks:
+ - id: bandit
+ args: ["-s", "B101"]
+
+ # yaml formatting
+ - repo: https://github.com/pre-commit/mirrors-prettier
+ rev: v3.0.0-alpha.6
+ hooks:
+ - id: prettier
+ types: [yaml]
+ exclude: "environment.yaml"
+
+ # shell scripts linter
+ - repo: https://github.com/shellcheck-py/shellcheck-py
+ rev: v0.9.0.2
+ hooks:
+ - id: shellcheck
+
+ # md formatting
+ - repo: https://github.com/executablebooks/mdformat
+ rev: 0.7.16
+ hooks:
+ - id: mdformat
+ args: ["--number"]
+ additional_dependencies:
+ - mdformat-gfm
+ - mdformat-tables
+ - mdformat_frontmatter
+ # - mdformat-toc
+ # - mdformat-black
+
+ # word spelling linter
+ - repo: https://github.com/codespell-project/codespell
+ rev: v2.2.4
+ hooks:
+ - id: codespell
+ args:
+ - --skip=logs/**,data/**,*.ipynb
+ # - --ignore-words-list=abc,def
+
+ # jupyter notebook cell output clearing
+ - repo: https://github.com/kynan/nbstripout
+ rev: 0.6.1
+ hooks:
+ - id: nbstripout
+
+ # jupyter notebook linting
+ - repo: https://github.com/nbQA-dev/nbQA
+ rev: 1.6.3
+ hooks:
+ - id: nbqa-black
+ args: ["--line-length=99"]
+ - id: nbqa-isort
+ args: ["--profile=black"]
+ - id: nbqa-flake8
+ args:
+ [
+ "--extend-ignore=E203,E402,E501,F401,F841",
+ "--exclude=logs/*,data/*",
+ ]
diff --git a/.project-root b/.project-root
new file mode 100644
index 0000000000000000000000000000000000000000..63eab774b9e36aa1a46cbd31b59cbd373bc5477f
--- /dev/null
+++ b/.project-root
@@ -0,0 +1,2 @@
+# this file is required for inferring the project root directory
+# do not delete
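+#
+# illustrative sketch (assumption: a root-finding helper such as rootutils is what train.py uses to locate this file):
+#   import rootutils
+#   rootutils.setup_root(__file__, indicator=".project-root", pythonpath=True)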
diff --git a/Makefile b/Makefile
new file mode 100644
index 0000000000000000000000000000000000000000..38184df93ea2c09f6d527abbb7f7c804b014284c
--- /dev/null
+++ b/Makefile
@@ -0,0 +1,30 @@
+
+help: ## Show help
+ @grep -E '^[.a-zA-Z_-]+:.*?## .*$$' $(MAKEFILE_LIST) | awk 'BEGIN {FS = ":.*?## "}; {printf "\033[36m%-30s\033[0m %s\n", $$1, $$2}'
+
+clean: ## Clean autogenerated files
+ rm -rf dist
+ find . -type f -name "*.DS_Store" -ls -delete
+ find . | grep -E "(__pycache__|\.pyc|\.pyo)" | xargs rm -rf
+ find . | grep -E ".pytest_cache" | xargs rm -rf
+ find . | grep -E ".ipynb_checkpoints" | xargs rm -rf
+ rm -f .coverage
+
+clean-logs: ## Clean logs
+ rm -rf logs/**
+
+format: ## Run pre-commit hooks
+ pre-commit run -a
+
+sync: ## Merge changes from main branch to your current branch
+ git pull
+ git pull origin main
+
+test: ## Run tests, excluding slow ones
+ pytest -k "not slow"
+
+test-full: ## Run all tests
+ pytest
+
+train: ## Train the model
+ python src/train.py
diff --git a/app.py b/app.py
index 9730b4f8c724389a10af89915d5cae7d59f2f3c7..bf716a2f7485ab8544195ff5c40df7baca596c77 100644
--- a/app.py
+++ b/app.py
@@ -16,10 +16,10 @@ from albumentations.pytorch.transforms import ToTensorV2
from src.models.components.cdnetv1 import CDnetV1
from src.models.components.cdnetv2 import CDnetV2
-from src.models.components.dual_branch import Dual_Branch
-from src.models.components.hrcloud import HRcloudNet
+from src.models.components.dbnet import DBNet
+from src.models.components.hrcloudnet import HRCloudNet
from src.models.components.mcdnet import MCDNet
-from src.models.components.scnn import SCNNNet
+from src.models.components.scnn import SCNN
class Application:
@@ -28,10 +28,10 @@ class Application:
self.models = {
"cdnetv1": CDnetV1(num_classes=2).to(self.device),
"cdnetv2": CDnetV2(num_classes=2).to(self.device),
- "hrcloud": HRcloudNet(num_classes=2).to(self.device),
+ "hrcloudnet": HRCloudNet(num_classes=2).to(self.device),
"mcdnet": MCDNet(in_channels=3, num_classes=2).to(self.device),
- "scnn": SCNNNet(num_classes=2).to(self.device),
- "dbnet": Dual_Branch(img_size=256, in_channels=3, num_classes=2).to(
+ "scnn": SCNN(num_classes=2).to(self.device),
+ "dbnet": DBNet(img_size=256, in_channels=3, num_classes=2).to(
self.device
),
}
diff --git a/configs/__init__.py b/configs/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..56bf7f4aa4906bc0f997132708cc0826c198e4aa
--- /dev/null
+++ b/configs/__init__.py
@@ -0,0 +1 @@
+# this file is needed here to include configs when building project as a package
diff --git a/configs/callbacks/default.yaml b/configs/callbacks/default.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..c9bf2fb8e6846c55916653a7b520e2cd624eef35
--- /dev/null
+++ b/configs/callbacks/default.yaml
@@ -0,0 +1,22 @@
+defaults:
+ - model_checkpoint
+ - early_stopping
+ - model_summary
+ - rich_progress_bar
+ - _self_
+
+model_checkpoint:
+ dirpath: ${paths.output_dir}/checkpoints
+ filename: "epoch_{epoch:03d}"
+ monitor: "val/acc"
+ mode: "max"
+ save_last: True
+ auto_insert_metric_name: False
+
+early_stopping:
+ monitor: "val/acc"
+ patience: 100
+ mode: "max"
+
+model_summary:
+ max_depth: -1
diff --git a/configs/callbacks/early_stopping.yaml b/configs/callbacks/early_stopping.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..c826c8d58651a5e2c7cca0e99948a9b6ccabccf3
--- /dev/null
+++ b/configs/callbacks/early_stopping.yaml
@@ -0,0 +1,15 @@
+# https://lightning.ai/docs/pytorch/stable/api/lightning.pytorch.callbacks.EarlyStopping.html
+
+early_stopping:
+ _target_: lightning.pytorch.callbacks.EarlyStopping
+ monitor: ??? # quantity to be monitored, must be specified !!!
+ min_delta: 0. # minimum change in the monitored quantity to qualify as an improvement
+ patience: 3 # number of checks with no improvement after which training will be stopped
+ verbose: False # verbosity mode
+ mode: "min" # "max" means higher metric value is better, can be also "min"
+ strict: True # whether to crash the training if monitor is not found in the validation metrics
+ check_finite: True # when set True, stops training when the monitor becomes NaN or infinite
+ stopping_threshold: null # stop training immediately once the monitored quantity reaches this threshold
+ divergence_threshold: null # stop training as soon as the monitored quantity becomes worse than this threshold
+ check_on_train_epoch_end: null # whether to run early stopping at the end of the training epoch
+ # log_rank_zero_only: False # this keyword argument isn't available in stable version
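+
+# illustrative usage (hypothetical values): `monitor` has no default here, so set it from the
+# including config (see configs/callbacks/default.yaml) or from the command line, e.g.
+#   python train.py callbacks.early_stopping.monitor="val/loss" callbacks.early_stopping.patience=10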
diff --git a/configs/callbacks/model_checkpoint.yaml b/configs/callbacks/model_checkpoint.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..bf946e88b1ecfaf96efa91428e4f38e17267b25f
--- /dev/null
+++ b/configs/callbacks/model_checkpoint.yaml
@@ -0,0 +1,17 @@
+# https://lightning.ai/docs/pytorch/stable/api/lightning.pytorch.callbacks.ModelCheckpoint.html
+
+model_checkpoint:
+ _target_: lightning.pytorch.callbacks.ModelCheckpoint
+ dirpath: null # directory to save the model file
+ filename: null # checkpoint filename
+ monitor: null # name of the logged metric which determines when model is improving
+ verbose: False # verbosity mode
+ save_last: null # additionally always save an exact copy of the last checkpoint to a file last.ckpt
+ save_top_k: 1 # save k best models (determined by above metric)
+ mode: "min" # "max" means higher metric value is better, can be also "min"
+ auto_insert_metric_name: True # when True, the checkpoints filenames will contain the metric name
+ save_weights_only: False # if True, then only the model’s weights will be saved
+ every_n_train_steps: null # number of training steps between checkpoints
+ train_time_interval: null # checkpoints are monitored at the specified time interval
+ every_n_epochs: null # number of epochs between checkpoints
+ save_on_train_epoch_end: null # whether to run checkpointing at the end of the training epoch or the end of validation
diff --git a/configs/callbacks/model_summary.yaml b/configs/callbacks/model_summary.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..b75981d8cd5d73f61088d80495dc540274bca3d1
--- /dev/null
+++ b/configs/callbacks/model_summary.yaml
@@ -0,0 +1,5 @@
+# https://lightning.ai/docs/pytorch/stable/api/lightning.pytorch.callbacks.RichModelSummary.html
+
+model_summary:
+ _target_: lightning.pytorch.callbacks.RichModelSummary
+ max_depth: 1 # the maximum depth of layer nesting that the summary will include
diff --git a/configs/callbacks/none.yaml b/configs/callbacks/none.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/configs/callbacks/rich_progress_bar.yaml b/configs/callbacks/rich_progress_bar.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..de6f1ccb11205a4db93645fb6f297e50205de172
--- /dev/null
+++ b/configs/callbacks/rich_progress_bar.yaml
@@ -0,0 +1,4 @@
+# https://lightning.ai/docs/pytorch/latest/api/lightning.pytorch.callbacks.RichProgressBar.html
+
+rich_progress_bar:
+ _target_: lightning.pytorch.callbacks.RichProgressBar
diff --git a/configs/data/CloudSEN12/README.md b/configs/data/CloudSEN12/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..738132e3f672c3222181b8834f4f6e4d8efd398f
--- /dev/null
+++ b/configs/data/CloudSEN12/README.md
@@ -0,0 +1,52 @@
+# CloudSEN12
+
+> [CloudSEN12, a global dataset for semantic understanding of cloud and cloud shadow in Sentinel-2](https://www.nature.com/articles/s41597-022-01878-2)
+
+## Introduction
+
+- [Official Site](https://cloudsen12.github.io/download.html)
+- [Paper Download](https://www.nature.com/articles/s41597-022-01878-2.pdf)
+- Data Download: [Hugging Face](https://huggingface.co/datasets/csaybar/CloudSEN12-high)
+
+## Abstract
+
+Accurately characterizing clouds and their shadows is a long-standing problem in the Earth Observation community. Recent works showcase the necessity to improve cloud detection methods for imagery acquired by the Sentinel-2 satellites. However, the lack of consensus and transparency in existing reference datasets hampers the benchmarking of current cloud detection methods. Exploiting the analysis-ready data offered by the Copernicus program, we created CloudSEN12, a new multi-temporal global dataset to foster research in cloud and cloud shadow detection. CloudSEN12 has 49,400 image patches, including Sentinel-2 level-1C and level-2A multi-spectral data, Sentinel-1 synthetic aperture radar data, auxiliary remote sensing products, different hand-crafted annotations to label the presence of thick and thin clouds and cloud shadows, and the results from eight state-of-the-art cloud detection algorithms. At present, CloudSEN12 exceeds all previous efforts in terms of annotation richness, scene variability, geographic distribution, metadata complexity, quality control, and number of samples.
+
+## Dataset
+
+CloudSEN12 is a LARGE dataset (~1 TB) for cloud semantic understanding that consists of 49,400 image patches (IP) that are evenly spread throughout all continents except Antarctica. Each IP covers 5090 x 5090 meters and contains data from Sentinel-2 levels 1C and 2A, hand-crafted annotations of thick and thin clouds and cloud shadows, Sentinel-1 Synthetic Aperture Radar (SAR), digital elevation model, surface water occurrence, land cover classes, and cloud mask results from six cutting-edge cloud detection algorithms.
+
+
+
+```yaml
+name: CloudSEN12
+source: Sentinel-1,2
+band: 12
+resolution: 10m
+pixel: 512x512
+train: 8490
+val: 535
+test: 975
+disk: (~1 TB)
+annotation:
+ - 0: Clear
+ - 1: Thick cloud
+ - 2: Thin cloud
+ - 3: Cloud shadow
+scene: -
+```
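+
+The models registered in `app.py` are two-class (cloud / clear) segmenters, so the four CloudSEN12 classes typically need to be collapsed into a binary mask before training. The snippet below is only an illustrative sketch of such a mapping (treating cloud shadow as "not cloud" is a design choice, not part of the dataset); class values follow the annotation table above.
+
+```python
+import numpy as np
+
+
+def cloudsen12_to_binary(label: np.ndarray) -> np.ndarray:
+    """Map CloudSEN12 labels (0 clear, 1 thick cloud, 2 thin cloud, 3 cloud shadow)
+    to a binary mask where 1 = cloud (thick or thin) and 0 = everything else."""
+    return np.isin(label, (1, 2)).astype(np.uint8)
+
+
+# example: a 2x3 label patch
+label = np.array([[0, 1, 2], [3, 0, 1]])
+print(cloudsen12_to_binary(label))  # 1 where the label is thick or thin cloud
+```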
+
+## Citation
+
+```bibtex
+@article{cloudsen12,
+ title={CloudSEN12, a global dataset for semantic understanding of cloud and cloud shadow in Sentinel-2},
+ author={Aybar, Cesar and Ysuhuaylas, Luis and Loja, Jhomira and Gonzales, Karen and Herrera, Fernando and Bautista, Lesly and Yali, Roy and Flores, Angie and Diaz, Lissette and Cuenca, Nicole and others},
+ journal={Scientific data},
+ volume={9},
+ number={1},
+ pages={782},
+ year={2022},
+ publisher={Nature Publishing Group UK London}
+}
+```
diff --git a/configs/data/GF12-MS-WHU/README.md b/configs/data/GF12-MS-WHU/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..b9489e4d9195cd9c627d063ce0be1616a2914de2
--- /dev/null
+++ b/configs/data/GF12-MS-WHU/README.md
@@ -0,0 +1,72 @@
+# GaoFen12
+
+> [Transferring Deep Models for Cloud Detection in Multisensor Images via Weakly Supervised Learning](https://ieeexplore.ieee.org/document/10436637)
+
+## Introduction
+
+- [Official Site](https://github.com/whu-ZSC/GF1-GF2MS-WHU)
+- [Paper Download](https://zhiweili.net/assets/pdf/2024.2_TGRS_Transferring%20Deep%20Models%20for%20Cloud%20Detection%20in%20Multisensor%20Images%20via%20Weakly%20Supervised%20Learning.pdf)
+- Data Download: [Baidu Disk](https://pan.baidu.com/s/1kBpym0mW_TS9YL1GQ9t8Hw) (password: 9zuf)
+
+## Abstract
+
+Recently, deep learning has been widely used for cloud detection in satellite images; however, due to radiometric and spatial resolution differences in images from different sensors and the time-consuming process of manually labeling cloud detection datasets, it is difficult to effectively generalize deep learning models for cloud detection in multisensor images. This article proposes a weakly supervised learning method for transferring deep models for cloud detection in multisensor images (TransMCD), which leverages the generalization of deep models and the spectral features of clouds to construct a pseudo-label dataset to improve the generalization of models. A deep model is first pretrained using a well-annotated cloud detection dataset, which is used to obtain a rough cloud mask of an unlabeled target image. The rough mask can be used to determine the spectral threshold adaptively for cloud segmentation of the target image. Block-level pseudo labels with high confidence in the target image are selected using the rough mask and spectral mask. An unsupervised segmentation technique is used to construct a high-quality pixel-level pseudo-label dataset. Finally, the pseudo-label dataset is used as supervised information for transferring the pretrained model to the target image. The TransMCD method was validated by transferring a model trained on 16-m Gaofen-1 wide field of view (WFV) images to 8-m Gaofen-1, 4-m Gaofen-2, and 10-m Sentinel-2 images. The F1-score of the transferred models on target images achieves improvements of 1.23%–9.63% over the pretrained models, which is comparable to the fully-supervised models trained with well-annotated target images, suggesting the efficiency of the TransMCD method for cloud detection in multisensor images.
+
+## Dataset
+
+### GF1MS-WHU Dataset
+
+> The two GF-1 PMS sensors have four MS bands with an 8-m spatial resolution and a panchromatic (PAN) band with a higher spatial resolution of 2 m. The spectral range of the MS bands is identical to that of the WFV sensors. In this study, 141 unlabeled images collected from various regions in China were used as the training data for the proposed method. In addition, 33 labeled images were used as the training data for the fully supervised methods, as well as the validation data for the different methods. The acquisition of the images spanned from June 2014 to December 2020 and encompassed four MS bands in both PMS sensors. Note that Fig. 7 only presents the distribution regions of the labeled images.
+
+```yaml
+name: GF1MS-WHU
+source: GaoFen-1
+band: 4 (MS)
+resolution: 8m (MS), 2m (PAN)
+pixel: 250x250
+train: 6343
+val: -
+test: 4085
+disk: 10.8GB
+annotation:
+ - 0: clear sky
+ - 1: cloud
+scene: [Forest,Urban,Barren,Water,Farmland,Grass,Wetland]
+```
+
+### GF2MS-WHU Dataset
+
+> The GF-2 satellite is configured with two PMS sensors. Each sensor has four MS bands with a 4-m spatial resolution and a PAN band with a 1-m spatial resolution. The GF-2 PMS sensors have the same bandwidth as the GF-1 WFV sensors. In this study, 163 unlabeled images obtained from Hubei, Jilin, and Hainan provinces were used as the training data for the proposed method, and 29 labeled images were used as the training data for the fully supervised methods, as well as the validation data for the different methods. The images were acquired from June 2014 to October 2020 and included four MS bands in both PMS sensors.
+
+```yaml
+name: GF2MS-WHU
+source: GaoFen-2
+band: 4 (MS)
+resolution: 4m (MS), 1m (PAN)
+pixel: 250x250
+train: 14357
+val: -
+test: 7560
+disk: 26.7GB
+annotation:
+ - 0: clear sky
+ - 1: cloud
+scene: [Forest,Urban,Barren,Water,Farmland,Grass,Wetland]
+```
+
+
+## Citation
+
+```bibtex
+@ARTICLE{gaofen12,
+ author={Zhu, Shaocong and Li, Zhiwei and Shen, Huanfeng},
+ journal={IEEE Transactions on Geoscience and Remote Sensing},
+ title={Transferring Deep Models for Cloud Detection in Multisensor Images via Weakly Supervised Learning},
+ year={2024},
+ volume={62},
+ number={},
+ pages={1-18},
+ keywords={Cloud computing;Clouds;Sensors;Predictive models;Supervised learning;Image segmentation;Deep learning;Cloud detection;deep learning;multisensor images;weakly supervised learning},
+ doi={10.1109/TGRS.2024.3358824}
+}
+```
diff --git a/configs/data/L8-Biome/README.md b/configs/data/L8-Biome/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..ba834dc4a86721a313f18746482a0d2b30a04364
--- /dev/null
+++ b/configs/data/L8-Biome/README.md
@@ -0,0 +1,56 @@
+# L8-Biome
+
+> [Cloud detection algorithm comparison and validation for operational Landsat data products](https://www.sciencedirect.com/science/article/abs/pii/S0034425717301293)
+
+## Introduction
+
+- [Official Site](https://landsat.usgs.gov/landsat-8-cloud-cover-assessment-validation-data)
+- [Paper Download](https://gerslab.cahnr.uconn.edu/wp-content/uploads/sites/2514/2021/06/1-s2.0-S0034425717301293-Steve_Foga_cloud_detection_2017.pdf)
+- Data Download: [USGS](https://landsat.usgs.gov/landsat-8-cloud-cover-assessment-validation-data)
+
+## Abstract
+
+Clouds are a pervasive and unavoidable issue in satellite-borne optical imagery. Accurate, well-documented, and automated cloud detection algorithms are necessary to effectively leverage large collections of remotely sensed data. The Landsat project is uniquely suited for comparative validation of cloud assessment algorithms because the modular architecture of the Landsat ground system allows for quick evaluation of new code, and because Landsat has the most comprehensive manual truth masks of any current satellite data archive. Currently, the Landsat Level-1 Product Generation System (LPGS) uses separate algorithms for determining clouds, cirrus clouds, and snow and/or ice probability on a per-pixel basis. With more bands onboard the Landsat 8 Operational Land Imager (OLI)/Thermal Infrared Sensor (TIRS) satellite, and a greater number of cloud masking algorithms, the U.S. Geological Survey (USGS) is replacing the current cloud masking workflow with a more robust algorithm that is capable of working across multiple Landsat sensors with minimal modification. Because of the inherent error from stray light and intermittent data availability of TIRS, these algorithms need to operate both with and without thermal data. In this study, we created a workflow to evaluate cloud and cloud shadow masking algorithms using cloud validation masks manually derived from both Landsat 7 Enhanced Thematic Mapper Plus (ETM+) and Landsat 8 OLI/TIRS data. We created a new validation dataset consisting of 96 Landsat 8 scenes, representing different biomes and proportions of cloud cover. We evaluated algorithm performance by overall accuracy, omission error, and commission error for both cloud and cloud shadow. We found that CFMask, C code based on the Function of Mask (Fmask) algorithm, and its confidence bands have the best overall accuracy among the many algorithms tested using our validation data. The Artificial Thermal-Automated Cloud Cover Algorithm (AT-ACCA) is the most accurate nonthermal-based algorithm. We give preference to CFMask for operational cloud and cloud shadow detection, as it is derived from a priori knowledge of physical phenomena and is operable without geographic restriction, making it useful for current and future land imaging missions without having to be retrained in a machine-learning environment.
+
+## Dataset
+
+This collection contains 96 Pre-Collection Landsat 8 Operational Land Imager (OLI)/Thermal Infrared Sensor (TIRS) terrain-corrected (Level-1T) scenes, distributed across the biomes listed below. Manually generated cloud masks are used to validate cloud cover assessment algorithms, which in turn are intended to compute the percentage of cloud cover in each scene.
+
+
+
+
+```yaml
+name: l8_biome
+source: Landsat-8 OLI/TIRS
+band: 9
+resolution: 30m
+pixel: ∼7000 × 6000
+train: -
+val: -
+test: -
+disk: 88GB
+annotation:
+ - 0: Fill
+ - 64: Cloud Shadow
+ - 128: Clear
+ - 192: Thin Cloud
+ - 255: Cloud
+scene: [Barren,Forest,Grass/Crops,Shrubland,Snow/Ice,Urban,Water,Wetlands]
+```
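+
+Since the reference masks encode classes as the byte values above (rather than consecutive ids), they are usually remapped before training. The snippet below is an illustrative sketch of such a remapping for a binary cloud/clear setup; treating shadow as "not cloud" and using 255 as an ignore index are assumptions, not part of the dataset definition.
+
+```python
+import numpy as np
+
+# L8-Biome mask values, per the annotation table above
+FILL, SHADOW, CLEAR, THIN_CLOUD, CLOUD = 0, 64, 128, 192, 255
+
+
+def l8biome_to_binary(mask: np.ndarray) -> np.ndarray:
+    """Return 1 for thin/thick cloud, 0 for clear or shadow, and 255 for fill pixels."""
+    out = np.isin(mask, (THIN_CLOUD, CLOUD)).astype(np.uint8)
+    out[mask == FILL] = 255  # ignore index (assumption) so fill pixels are excluded from the loss
+    return out
+```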
+
+## Citation
+
+```bibtex
+@article{l8biome,
+ title = {Cloud detection algorithm comparison and validation for operational Landsat data products},
+ journal = {Remote Sensing of Environment},
+ volume = {194},
+ pages = {379-390},
+ year = {2017},
+ issn = {0034-4257},
+ doi = {https://doi.org/10.1016/j.rse.2017.03.026},
+ url = {https://www.sciencedirect.com/science/article/pii/S0034425717301293},
+ author = {Steve Foga and Pat L. Scaramuzza and Song Guo and Zhe Zhu and Ronald D. Dilley and Tim Beckmann and Gail L. Schmidt and John L. Dwyer and M. {Joseph Hughes} and Brady Laue},
+ keywords = {Landsat, CFMask, Cloud detection, Cloud validation masks, Biome sampling, Data products},
+}
+```
diff --git a/configs/data/celeba.yaml b/configs/data/celeba.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..25890188131c0231ff3f155a6a6713082f235777
--- /dev/null
+++ b/configs/data/celeba.yaml
@@ -0,0 +1,8 @@
+_target_: src.data.celeba_datamodule.CelebADataModule
+size: 512 # image size
+test_dataset_size: 3000
+conditions: [] # [] for image-only, ['seg_mask', 'text'] for multi-modal conditions, ['seg_mask'] for segmentation mask only, ['text'] for text only
+batch_size: 2
+num_workers: 8
+pin_memory: False
+persistent_workers: False
\ No newline at end of file
diff --git a/configs/data/hrcwhu/README.md b/configs/data/hrcwhu/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..e51ea6785e39c9a6c8fbfaa3b14d38273822e1bd
--- /dev/null
+++ b/configs/data/hrcwhu/README.md
@@ -0,0 +1,56 @@
+# HRC_WHU
+
+> [Deep learning based cloud detection for medium and high resolution remote sensing images of different sensors](https://www.sciencedirect.com/science/article/pii/S0924271619300565)
+
+## Introduction
+
+- [Official Site](http://sendimage.whu.edu.cn/en/hrc_whu/)
+- [Paper Download](http://sendimage.whu.edu.cn/en/wp-content/uploads/2019/03/2019_PHOTO_Zhiwei-Li_Deep-learning-based-cloud-detection-for-medium-and-high-resolution-remote-sensing-images-of-different-sensors.pdf)
+- Data Download: [Baidu Disk](https://pan.baidu.com/s/1thOTKVO2iTAalFAjFI2_ZQ) (password: ihfb) or [Google Drive](https://drive.google.com/file/d/1qqikjaX7tkfOONsF5EtR4vl6J7sToA6p/view?usp=sharing)
+
+## Abstract
+
+Cloud detection is an important preprocessing step for the precise application of optical satellite imagery. In this paper, we propose a deep learning based cloud detection method named multi-scale convolutional feature fusion (MSCFF) for remote sensing images of different sensors. In the network architecture of MSCFF, the symmetric encoder-decoder module, which provides both local and global context by densifying feature maps with trainable convolutional filter banks, is utilized to extract multi-scale and high-level spatial features. The feature maps of multiple scales are then up-sampled and concatenated, and a novel multi-scale feature fusion module is designed to fuse the features of different scales for the output. The two output feature maps of the network are cloud and cloud shadow maps, which are in turn fed to binary classifiers outside the model to obtain the final cloud and cloud shadow mask. The MSCFF method was validated on hundreds of globally distributed optical satellite images, with spatial resolutions ranging from 0.5 to 50 m, including Landsat-5/7/8, Gaofen-1/2/4, Sentinel-2, Ziyuan-3, CBERS-04, Huanjing-1, and collected high-resolution images exported from Google Earth. The experimental results show that MSCFF achieves a higher accuracy than the traditional rule-based cloud detection methods and the state-of-the-art deep learning models, especially in bright surface covered areas. The effectiveness of MSCFF means that it has great promise for the practical application of cloud detection for multiple types of medium and high-resolution remote sensing images. Our established global high-resolution cloud detection validation dataset has been made available online (http://sendimage.whu.edu.cn/en/mscff/).
+
+## Dataset
+
+The high-resolution cloud cover validation dataset was created by the SENDIMAGE Lab in Wuhan University, and has been termed HRC_WHU. The HRC_WHU data comprise 150 high-resolution images acquired with three RGB channels and a resolution varying from 0.5 to 15 m in different global regions. As shown in Fig. 1, the images were collected from Google Earth (Google Inc.) in five main land-cover types, i.e., water, vegetation, urban, snow/ice, and barren. The associated reference cloud masks were digitized by experts in the field of remote sensing image interpretation. The established high-resolution cloud cover validation dataset has been made available online.
+
+
+
+```yaml
+name: hrc_whu
+source: google earth
+band: 3 (rgb)
+resolution: 0.5m-15m
+pixel: 1280x720
+train: 120
+val: null
+test: 30
+disk: 168MB
+annotation:
+ - 0: clear sky
+ - 1: cloud
+scene: [water, vegetation, urban, snow/ice, barren]
+```
+
+## Annotation
+
+In the procedure of delineating the cloud mask for high-resolution imagery, we first stretched the cloudy image to the appropriate contrast in Adobe Photoshop. The lasso tool and magic wand tool were then alternately used to mark the locations of the clouds in the image. The manually labeled reference mask was finally created by assigning the pixel values of cloud and clear sky to 255 and 0, respectively. Note that a tolerance of 5–30 was set when using the magic wand tool, and the lasso tool was used to modify the areas that could not be correctly selected by the magic wand tool. As we did in a previous study (Li et al., 2017), the thin clouds were labeled as cloud if they were visually identifiable and the underlying surface could not be seen clearly. Considering that cloud shadows in high-resolution images are rare and hard to accurately select, only clouds were labeled in the reference masks.
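+
+Given the labeling convention above (cloud = 255, clear sky = 0), the reference masks can be converted to the {0: clear sky, 1: cloud} ids listed in the dataset summary. The snippet below is only an illustrative sketch; the file path and grayscale conversion are assumptions about how the masks are stored on disk.
+
+```python
+import numpy as np
+from PIL import Image
+
+
+def load_hrcwhu_mask(path: str) -> np.ndarray:
+    """Load a reference mask (cloud = 255, clear sky = 0) and convert it to {0, 1} class ids."""
+    mask = np.array(Image.open(path).convert("L"))
+    return (mask == 255).astype(np.uint8)
+```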
+
+## Citation
+
+```bibtex
+@article{hrc_whu,
+ title = {Deep learning based cloud detection for medium and high resolution remote sensing images of different sensors},
+ journal = {ISPRS Journal of Photogrammetry and Remote Sensing},
+ volume = {150},
+ pages = {197-212},
+ year = {2019},
+ issn = {0924-2716},
+ doi = {https://doi.org/10.1016/j.isprsjprs.2019.02.017},
+ url = {https://www.sciencedirect.com/science/article/pii/S0924271619300565},
+ author = {Zhiwei Li and Huanfeng Shen and Qing Cheng and Yuhao Liu and Shucheng You and Zongyi He},
+ keywords = {Cloud detection, Cloud shadow, Convolutional neural network, Multi-scale, Convolutional feature fusion, MSCFF}
+}
+```
diff --git a/configs/data/hrcwhu/hrcwhu.yaml b/configs/data/hrcwhu/hrcwhu.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..0d8d8fb405c3b22400a55a74f60a29d78ee42354
--- /dev/null
+++ b/configs/data/hrcwhu/hrcwhu.yaml
@@ -0,0 +1,89 @@
+_target_: src.data.hrcwhu_datamodule.HRCWHUDataModule
+root: data/hrcwhu
+train_pipeline:
+ all_transform:
+ _target_: albumentations.Compose
+ transforms:
+ - _target_: albumentations.HorizontalFlip
+ p: 0.5
+ - _target_: albumentations.ShiftScaleRotate
+ p: 1
+ - _target_: albumentations.RandomCrop
+ height: 256
+ width: 256
+ always_apply: true
+ - _target_: albumentations.GaussNoise
+ p: 0.2
+ - _target_: albumentations.Perspective
+ p: 0.5
+ - _target_: albumentations.OneOf
+ transforms:
+ - _target_: albumentations.CLAHE
+ p: 1
+ - _target_: albumentations.RandomGamma
+ p: 1
+ p: 0.9
+
+ - _target_: albumentations.OneOf
+ transforms:
+ - _target_: albumentations.Sharpen
+ p: 1
+ - _target_: albumentations.Blur
+ p: 1
+ - _target_: albumentations.MotionBlur
+ p: 1
+ p: 0.9
+
+ - _target_: albumentations.OneOf
+ transforms:
+ - _target_: albumentations.RandomBrightnessContrast
+ p: 1
+ - _target_: albumentations.HueSaturationValue
+ p: 1
+ p: 0.9
+
+ img_transform:
+ _target_: albumentations.Compose
+ transforms:
+ - _target_: albumentations.ToFloat
+ max_value: 255.0
+ - _target_: albumentations.pytorch.transforms.ToTensorV2
+
+ ann_transform: null
+val_pipeline:
+ all_transform:
+ _target_: albumentations.Compose
+ transforms:
+ - _target_: albumentations.Resize
+ height: 256
+ width: 256
+
+ img_transform:
+ _target_: albumentations.Compose
+ transforms:
+ - _target_: albumentations.ToFloat
+ max_value: 255.0
+ - _target_: albumentations.pytorch.transforms.ToTensorV2
+ ann_transform: null
+
+test_pipeline:
+ all_transform:
+ _target_: albumentations.Compose
+ transforms:
+ - _target_: albumentations.Resize
+ height: 256
+ width: 256
+
+ img_transform:
+ _target_: albumentations.Compose
+ transforms:
+ - _target_: albumentations.ToFloat
+ max_value: 255.0
+ - _target_: albumentations.pytorch.transforms.ToTensorV2
+ ann_transform: null
+
+seed: 42
+batch_size: 8
+num_workers: 8
+pin_memory: True
+persistent_workers: True
\ No newline at end of file
diff --git a/configs/data/mnist.yaml b/configs/data/mnist.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..51bfaff092a1e3fe2551c89dafa7c7b90ebffe40
--- /dev/null
+++ b/configs/data/mnist.yaml
@@ -0,0 +1,6 @@
+_target_: src.data.mnist_datamodule.MNISTDataModule
+data_dir: ${paths.data_dir}
+batch_size: 128 # Needs to be divisible by the number of devices (e.g., if in a distributed setup)
+train_val_test_split: [55_000, 5_000, 10_000]
+num_workers: 0
+pin_memory: False
diff --git a/configs/debug/default.yaml b/configs/debug/default.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..1886902b39f1be560e314bce7b3778f95b44754c
--- /dev/null
+++ b/configs/debug/default.yaml
@@ -0,0 +1,35 @@
+# @package _global_
+
+# default debugging setup, runs 1 full epoch
+# other debugging configs can inherit from this one
+
+# overwrite task name so debugging logs are stored in separate folder
+task_name: "debug"
+
+# disable callbacks and loggers during debugging
+callbacks: null
+logger: null
+
+extras:
+ ignore_warnings: False
+ enforce_tags: False
+
+# sets level of all command line loggers to 'DEBUG'
+# https://hydra.cc/docs/tutorials/basic/running_your_app/logging/
+hydra:
+ job_logging:
+ root:
+ level: DEBUG
+
+ # use this to also set hydra loggers to 'DEBUG'
+ # verbose: True
+
+trainer:
+ max_epochs: 1
+ accelerator: cpu # debuggers don't like gpus
+ devices: 1 # debuggers don't like multiprocessing
+ detect_anomaly: true # raise exception if NaN or +/-inf is detected in any tensor
+
+data:
+ num_workers: 0 # debuggers don't like multiprocessing
+ pin_memory: False # disable gpu memory pin
diff --git a/configs/debug/fdr.yaml b/configs/debug/fdr.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..7f2d34fa37c31017e749d5a4fc5ae6763e688b46
--- /dev/null
+++ b/configs/debug/fdr.yaml
@@ -0,0 +1,9 @@
+# @package _global_
+
+# runs 1 train, 1 validation and 1 test step
+
+defaults:
+ - default
+
+trainer:
+ fast_dev_run: true
diff --git a/configs/debug/limit.yaml b/configs/debug/limit.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..514d77fbd1475b03fff0372e3da3c2fa7ea7d190
--- /dev/null
+++ b/configs/debug/limit.yaml
@@ -0,0 +1,12 @@
+# @package _global_
+
+# uses only 1% of the training data and 5% of validation/test data
+
+defaults:
+ - default
+
+trainer:
+ max_epochs: 3
+ limit_train_batches: 0.01
+ limit_val_batches: 0.05
+ limit_test_batches: 0.05
diff --git a/configs/debug/overfit.yaml b/configs/debug/overfit.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..9906586a67a12aa81ff69138f589a366dbe2222f
--- /dev/null
+++ b/configs/debug/overfit.yaml
@@ -0,0 +1,13 @@
+# @package _global_
+
+# overfits to 3 batches
+
+defaults:
+ - default
+
+trainer:
+ max_epochs: 20
+ overfit_batches: 3
+
+# model ckpt and early stopping need to be disabled during overfitting
+callbacks: null
diff --git a/configs/debug/profiler.yaml b/configs/debug/profiler.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..2bd7da87ae23ed425ace99b09250a76a5634a3fb
--- /dev/null
+++ b/configs/debug/profiler.yaml
@@ -0,0 +1,12 @@
+# @package _global_
+
+# runs with execution time profiling
+
+defaults:
+ - default
+
+trainer:
+ max_epochs: 1
+ profiler: "simple"
+ # profiler: "advanced"
+ # profiler: "pytorch"
diff --git a/configs/eval.yaml b/configs/eval.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..be312992b2a486b04d83a54dbd8f670d94979709
--- /dev/null
+++ b/configs/eval.yaml
@@ -0,0 +1,18 @@
+# @package _global_
+
+defaults:
+ - _self_
+ - data: mnist # choose datamodule with `test_dataloader()` for evaluation
+ - model: mnist
+ - logger: null
+ - trainer: default
+ - paths: default
+ - extras: default
+ - hydra: default
+
+task_name: "eval"
+
+tags: ["dev"]
+
+# passing checkpoint path is necessary for evaluation
+ckpt_path: ???
diff --git a/configs/experiment/cnn.yaml b/configs/experiment/cnn.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..1e429cdcb29db6d31595acc75fa5af51972f42d1
--- /dev/null
+++ b/configs/experiment/cnn.yaml
@@ -0,0 +1,56 @@
+# @package _global_
+
+# to execute this experiment run:
+# python train.py experiment=example
+
+defaults:
+ - override /trainer: gpu
+ - override /data: mnist
+ - override /model: cnn
+ - override /logger: wandb
+ - override /callbacks: default
+
+# all parameters below will be merged with parameters from default configurations set above
+# this allows you to overwrite only specified parameters
+
+tags: ["mnist", "cnn"]
+
+seed: 42
+
+trainer:
+ min_epochs: 10
+ max_epochs: 10
+ gradient_clip_val: 0.5
+ devices: 1
+
+data:
+ batch_size: 128
+ train_val_test_split: [55_000, 5_000, 10_000]
+ num_workers: 31
+ pin_memory: False
+ persistent_workers: False
+
+model:
+ net:
+ dim: 32
+
+logger:
+ wandb:
+ project: "mnist"
+ name: "cnn"
+ aim:
+ experiment: "cnn"
+
+callbacks:
+ model_checkpoint:
+ dirpath: ${paths.output_dir}/checkpoints
+ filename: "epoch_{epoch:03d}"
+ monitor: "val/acc"
+ mode: "max"
+ save_last: True
+ auto_insert_metric_name: False
+
+ early_stopping:
+ monitor: "val/acc"
+ patience: 100
+ mode: "max"
\ No newline at end of file
diff --git a/configs/experiment/hrcwhu_cdnetv1.yaml b/configs/experiment/hrcwhu_cdnetv1.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..bce0bdcb8b33c209a21287bc1cbbb89f83f7c7db
--- /dev/null
+++ b/configs/experiment/hrcwhu_cdnetv1.yaml
@@ -0,0 +1,47 @@
+# @package _global_
+
+# to execute this experiment run:
+# python train.py experiment=example
+
+defaults:
+ - override /trainer: gpu
+ - override /data: hrcwhu/hrcwhu
+ - override /model: cdnetv1/cdnetv1
+ - override /logger: wandb
+ - override /callbacks: default
+
+# all parameters below will be merged with parameters from default configurations set above
+# this allows you to overwrite only specified parameters
+
+tags: ["hrcWhu", "cdnetv1"]
+
+seed: 42
+
+
+ # scheduler:
+ # _target_: torch.optim.lr_scheduler.ReduceLROnPlateau
+ # _partial_: true
+ # mode: min
+ # factor: 0.1
+ # patience: 10
+
+logger:
+ wandb:
+ project: "hrcWhu"
+ name: "cdnetv1"
+ aim:
+ experiment: "hrcwhu_cdnetv1"
+
+callbacks:
+ model_checkpoint:
+ dirpath: ${paths.output_dir}/checkpoints
+ filename: "epoch_{epoch:03d}"
+ monitor: "val/loss"
+ mode: "min"
+ save_last: True
+ auto_insert_metric_name: False
+
+ early_stopping:
+ monitor: "val/loss"
+ patience: 10
+ mode: "min"
\ No newline at end of file
diff --git a/configs/experiment/hrcwhu_cdnetv2.yaml b/configs/experiment/hrcwhu_cdnetv2.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..2b21dc6d872411bdf06db64300725ace1fa4f0a0
--- /dev/null
+++ b/configs/experiment/hrcwhu_cdnetv2.yaml
@@ -0,0 +1,47 @@
+# @package _global_
+
+# to execute this experiment run:
+# python train.py experiment=example
+
+defaults:
+ - override /trainer: gpu
+ - override /data: hrcwhu/hrcwhu
+ - override /model: cdnetv2/cdnetv2
+ - override /logger: wandb
+ - override /callbacks: default
+
+# all parameters below will be merged with parameters from default configurations set above
+# this allows you to overwrite only specified parameters
+
+tags: ["hrcWhu", "cdnetv2"]
+
+seed: 42
+
+
+ # scheduler:
+ # _target_: torch.optim.lr_scheduler.ReduceLROnPlateau
+ # _partial_: true
+ # mode: min
+ # factor: 0.1
+ # patience: 10
+
+logger:
+ wandb:
+ project: "hrcWhu"
+ name: "cdnetv2"
+ aim:
+ experiment: "hrcwhu_cdnetv2"
+
+callbacks:
+ model_checkpoint:
+ dirpath: ${paths.output_dir}/checkpoints
+ filename: "epoch_{epoch:03d}"
+ monitor: "val/loss"
+ mode: "min"
+ save_last: True
+ auto_insert_metric_name: False
+
+ early_stopping:
+ monitor: "val/loss"
+ patience: 10
+ mode: "min"
\ No newline at end of file
diff --git a/configs/experiment/hrcwhu_dbnet.yaml b/configs/experiment/hrcwhu_dbnet.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..388af5a19489a1e09bce6f6c397d8d8525ff634c
--- /dev/null
+++ b/configs/experiment/hrcwhu_dbnet.yaml
@@ -0,0 +1,48 @@
+# @package _global_
+
+# to execute this experiment run:
+# python train.py experiment=example
+
+defaults:
+ - override /trainer: gpu
+ - override /data: hrcwhu/hrcwhu
+ - override /model: dbnet/dbnet
+ - override /logger: wandb
+ - override /callbacks: default
+
+# all parameters below will be merged with parameters from default configurations set above
+# this allows you to overwrite only specified parameters
+
+tags: ["hrcWhu", "dbnet"]
+
+seed: 42
+
+
+
+ # scheduler:
+ # _target_: torch.optim.lr_scheduler.ReduceLROnPlateau
+ # _partial_: true
+ # mode: min
+ # factor: 0.1
+ # patience: 10
+
+logger:
+ wandb:
+ project: "hrcWhu"
+ name: "dbnet"
+ aim:
+ experiment: "hrcwhu_dbnet"
+
+callbacks:
+ model_checkpoint:
+ dirpath: ${paths.output_dir}/checkpoints
+ filename: "epoch_{epoch:03d}"
+ monitor: "val/loss"
+ mode: "min"
+ save_last: True
+ auto_insert_metric_name: False
+
+ early_stopping:
+ monitor: "val/loss"
+ patience: 10
+ mode: "min"
\ No newline at end of file
diff --git a/configs/experiment/hrcwhu_hrcloud.yaml b/configs/experiment/hrcwhu_hrcloud.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..13431da01c08bf765b800a1deba6659887145245
--- /dev/null
+++ b/configs/experiment/hrcwhu_hrcloud.yaml
@@ -0,0 +1,47 @@
+# @package _global_
+
+# to execute this experiment run:
+# python train.py experiment=example
+
+defaults:
+ - override /trainer: gpu
+ - override /data: hrcwhu/hrcwhu
+ - override /model: hrcloudnet/hrcloudnet
+ - override /logger: wandb
+ - override /callbacks: default
+
+# all parameters below will be merged with parameters from default configurations set above
+# this allows you to overwrite only specified parameters
+
+tags: ["hrcWhu", "hrcloud"]
+
+seed: 42
+
+
+ # scheduler:
+ # _target_: torch.optim.lr_scheduler.ReduceLROnPlateau
+ # _partial_: true
+ # mode: min
+ # factor: 0.1
+ # patience: 10
+
+logger:
+ wandb:
+ project: "hrcWhu"
+ name: "hrcloud"
+ aim:
+ experiment: "hrcwhu_hrcloud"
+
+callbacks:
+ model_checkpoint:
+ dirpath: ${paths.output_dir}/checkpoints
+ filename: "epoch_{epoch:03d}"
+ monitor: "val/loss"
+ mode: "min"
+ save_last: True
+ auto_insert_metric_name: False
+
+ early_stopping:
+ monitor: "val/loss"
+ patience: 10
+ mode: "min"
\ No newline at end of file
diff --git a/configs/experiment/hrcwhu_mcdnet.yaml b/configs/experiment/hrcwhu_mcdnet.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..cf8c4cb41259143277f6d40b3d7ce46071ebc96f
--- /dev/null
+++ b/configs/experiment/hrcwhu_mcdnet.yaml
@@ -0,0 +1,47 @@
+# @package _global_
+
+# to execute this experiment run:
+# python train.py experiment=example
+
+defaults:
+ - override /trainer: gpu
+ - override /data: hrcwhu/hrcwhu
+ - override /model: mcdnet/mcdnet
+ - override /logger: wandb
+ - override /callbacks: default
+
+# all parameters below will be merged with parameters from default configurations set above
+# this allows you to overwrite only specified parameters
+
+tags: ["hrcWhu", "mcdnet"]
+
+seed: 42
+
+
+# scheduler:
+# _target_: torch.optim.lr_scheduler.ReduceLROnPlateau
+# _partial_: true
+# mode: min
+# factor: 0.1
+# patience: 10
+
+logger:
+ wandb:
+ project: "hrcWhu"
+ name: "mcdnet"
+ aim:
+ experiment: "hrcwhu_mcdnet"
+
+callbacks:
+ model_checkpoint:
+ dirpath: ${paths.output_dir}/checkpoints
+ filename: "epoch_{epoch:03d}"
+ monitor: "val/loss"
+ mode: "min"
+ save_last: True
+ auto_insert_metric_name: False
+
+ early_stopping:
+ monitor: "val/loss"
+ patience: 10
+ mode: "min"
\ No newline at end of file
diff --git a/configs/experiment/hrcwhu_scnn.yaml b/configs/experiment/hrcwhu_scnn.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..89145349426347d5c22762778d06fbacde21ce74
--- /dev/null
+++ b/configs/experiment/hrcwhu_scnn.yaml
@@ -0,0 +1,47 @@
+# @package _global_
+
+# to execute this experiment run:
+# python train.py experiment=example
+
+defaults:
+ - override /trainer: gpu
+ - override /data: hrcwhu/hrcwhu
+ - override /model: scnn/scnn
+ - override /logger: wandb
+ - override /callbacks: default
+
+# all parameters below will be merged with parameters from default configurations set above
+# this allows you to overwrite only specified parameters
+
+tags: ["hrcWhu", "scnn"]
+
+seed: 42
+
+
+ # scheduler:
+ # _target_: torch.optim.lr_scheduler.ReduceLROnPlateau
+ # _partial_: true
+ # mode: min
+ # factor: 0.1
+ # patience: 10
+
+logger:
+ wandb:
+ project: "hrcWhu"
+ name: "scnn"
+ aim:
+ experiment: "hrcwhu_scnn"
+
+callbacks:
+ model_checkpoint:
+ dirpath: ${paths.output_dir}/checkpoints
+ filename: "epoch_{epoch:03d}"
+ monitor: "val/loss"
+ mode: "min"
+ save_last: True
+ auto_insert_metric_name: False
+
+ early_stopping:
+ monitor: "val/loss"
+ patience: 10
+ mode: "min"
\ No newline at end of file
diff --git a/configs/experiment/hrcwhu_unet.yaml b/configs/experiment/hrcwhu_unet.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..5247691cc1b5d258a7aa9e335c19f021670215b5
--- /dev/null
+++ b/configs/experiment/hrcwhu_unet.yaml
@@ -0,0 +1,68 @@
+# @package _global_
+
+# to execute this experiment run:
+# python train.py experiment=example
+
+defaults:
+ - override /trainer: gpu
+ - override /data: hrcwhu/hrcwhu
+ - override /model: null
+ - override /logger: wandb
+ - override /callbacks: default
+
+# all parameters below will be merged with parameters from default configurations set above
+# this allows you to overwrite only specified parameters
+
+tags: ["hrcWhu", "unet"]
+
+seed: 42
+
+
+
+model:
+ _target_: src.models.base_module.BaseLitModule
+
+ net:
+ _target_: src.models.components.unet.UNet
+ in_channels: 3
+ out_channels: 2
+
+ num_classes: 2
+
+ criterion:
+ _target_: torch.nn.CrossEntropyLoss
+
+ optimizer:
+ _target_: torch.optim.SGD
+ _partial_: true
+ lr: 0.1
+
+ scheduler: null
+
+ # scheduler:
+ # _target_: torch.optim.lr_scheduler.ReduceLROnPlateau
+ # _partial_: true
+ # mode: min
+ # factor: 0.1
+ # patience: 10
+
+logger:
+ wandb:
+ project: "hrcWhu"
+ name: "unet"
+ aim:
+ experiment: "hrcwhu_unet"
+
+callbacks:
+ model_checkpoint:
+ dirpath: ${paths.output_dir}/checkpoints
+ filename: "epoch_{epoch:03d}"
+ monitor: "val/loss"
+ mode: "min"
+ save_last: True
+ auto_insert_metric_name: False
+
+ early_stopping:
+ monitor: "val/loss"
+ patience: 10
+ mode: "min"
\ No newline at end of file
diff --git a/configs/experiment/lnn.yaml b/configs/experiment/lnn.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..20c0791bfc9504f3f159416f5f5060f28f530a9b
--- /dev/null
+++ b/configs/experiment/lnn.yaml
@@ -0,0 +1,57 @@
+# @package _global_
+
+# to execute this experiment run:
+# python train.py experiment=example
+
+defaults:
+ - override /trainer: gpu
+ - override /data: mnist
+ - override /model: lnn
+ - override /logger: wandb
+ - override /callbacks: default
+
+# all parameters below will be merged with parameters from default configurations set above
+# this allows you to overwrite only specified parameters
+
+tags: ["mnist", "lnn"]
+
+seed: 42
+
+trainer:
+ min_epochs: 10
+ max_epochs: 10
+ gradient_clip_val: 0.5
+ devices: 1
+
+data:
+ batch_size: 128
+ train_val_test_split: [55_000, 5_000, 10_000]
+ num_workers: 31
+ pin_memory: False
+ persistent_workers: False
+
+model:
+ net:
+ _target_: src.models.components.lnn.LNN
+ dim: 32
+
+logger:
+ wandb:
+ project: "mnist"
+ name: "lnn"
+ aim:
+ experiment: "lnn"
+
+callbacks:
+ model_checkpoint:
+ dirpath: ${paths.output_dir}/checkpoints
+ filename: "epoch_{epoch:03d}"
+ monitor: "val/acc"
+ mode: "max"
+ save_last: True
+ auto_insert_metric_name: False
+
+ early_stopping:
+ monitor: "val/acc"
+ patience: 100
+ mode: "max"
\ No newline at end of file
diff --git a/configs/extras/default.yaml b/configs/extras/default.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..b9c6b622283a647fbc513166fc14f016cc3ed8a0
--- /dev/null
+++ b/configs/extras/default.yaml
@@ -0,0 +1,8 @@
+# disable python warnings if they annoy you
+ignore_warnings: False
+
+# ask user for tags if none are provided in the config
+enforce_tags: True
+
+# pretty print config tree at the start of the run using Rich library
+print_config: True
diff --git a/configs/hparams_search/mnist_optuna.yaml b/configs/hparams_search/mnist_optuna.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..1391183ebcdec3d8f5eb61374e0719d13c7545da
--- /dev/null
+++ b/configs/hparams_search/mnist_optuna.yaml
@@ -0,0 +1,52 @@
+# @package _global_
+
+# example hyperparameter optimization of some experiment with Optuna:
+# python train.py -m hparams_search=mnist_optuna experiment=example
+
+defaults:
+ - override /hydra/sweeper: optuna
+
+# choose metric which will be optimized by Optuna
+# make sure this is the correct name of some metric logged in lightning module!
+optimized_metric: "val/acc_best"
+
+# here we define Optuna hyperparameter search
+# it optimizes for value returned from function with @hydra.main decorator
+# docs: https://hydra.cc/docs/next/plugins/optuna_sweeper
+hydra:
+ mode: "MULTIRUN" # set hydra to multirun by default if this config is attached
+
+ sweeper:
+ _target_: hydra_plugins.hydra_optuna_sweeper.optuna_sweeper.OptunaSweeper
+
+ # storage URL to persist optimization results
+ # for example, you can use SQLite if you set 'sqlite:///example.db'
+ storage: null
+
+ # name of the study to persist optimization results
+ study_name: null
+
+ # number of parallel workers
+ n_jobs: 1
+
+ # 'minimize' or 'maximize' the objective
+ direction: maximize
+
+ # total number of runs that will be executed
+ n_trials: 20
+
+ # choose Optuna hyperparameter sampler
+ # you can choose bayesian sampler (tpe), random search (without optimization), grid sampler, and others
+ # docs: https://optuna.readthedocs.io/en/stable/reference/samplers.html
+ sampler:
+ _target_: optuna.samplers.TPESampler
+ seed: 1234
+ n_startup_trials: 10 # number of random sampling runs before optimization starts
+
+ # define hyperparameter search space
+ params:
+ model.optimizer.lr: interval(0.0001, 0.1)
+ data.batch_size: choice(32, 64, 128, 256)
+ model.net.lin1_size: choice(64, 128, 256)
+ model.net.lin2_size: choice(64, 128, 256)
+ model.net.lin3_size: choice(32, 64, 128, 256)
diff --git a/configs/hydra/default.yaml b/configs/hydra/default.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..aace38d320b808a7e53ea4ee230992e5abe804e9
--- /dev/null
+++ b/configs/hydra/default.yaml
@@ -0,0 +1,19 @@
+# https://hydra.cc/docs/configure_hydra/intro/
+
+# enable color logging
+defaults:
+ - override hydra_logging: colorlog
+ - override job_logging: colorlog
+
+# output directory, generated dynamically on each run
+run:
+ dir: ${paths.log_dir}/${task_name}/runs/${logger.aim.experiment}/${now:%Y-%m-%d}_${now:%H-%M-%S}
+sweep:
+ dir: ${paths.log_dir}/${task_name}/multiruns/${logger.aim.experiment}/${now:%Y-%m-%d}_${now:%H-%M-%S}
+ subdir: ${hydra.job.num}
+
+job_logging:
+ handlers:
+ file:
+ # Incorporates fix from https://github.com/facebookresearch/hydra/pull/2242
+ filename: ${hydra.runtime.output_dir}/${task_name}.log
diff --git a/configs/local/.gitkeep b/configs/local/.gitkeep
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/configs/logger/aim.yaml b/configs/logger/aim.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..8f9f6adad7feb2780c2efd5ddb0ed053621e05f8
--- /dev/null
+++ b/configs/logger/aim.yaml
@@ -0,0 +1,28 @@
+# https://aimstack.io/
+
+# example usage in lightning module:
+# https://github.com/aimhubio/aim/blob/main/examples/pytorch_lightning_track.py
+
+# open the Aim UI with the following command (run in the folder containing the `.aim` folder):
+# `aim up`
+
+aim:
+ _target_: aim.pytorch_lightning.AimLogger
+ repo: ${paths.root_dir} # .aim folder will be created here
+ # repo: "aim://ip_address:port" # can instead provide IP address pointing to Aim remote tracking server which manages the repo, see https://aimstack.readthedocs.io/en/latest/using/remote_tracking.html#
+
+ # Aim allows grouping runs under an experiment name
+ experiment: null # any string, set to "default" if not specified
+
+ train_metric_prefix: "train/"
+ val_metric_prefix: "val/"
+ test_metric_prefix: "test/"
+
+ # sets the tracking interval in seconds for system usage metrics (CPU, GPU, memory, etc.)
+ system_tracking_interval: 10 # set to null to disable system metrics tracking
+
+ # enable/disable logging of system params such as installed packages, git info, env vars, etc.
+ log_system_params: true
+
+ # enable/disable tracking console logs (default value is true)
+ capture_terminal_logs: false # set to false to avoid infinite console log loop issue https://github.com/aimhubio/aim/issues/2550
diff --git a/configs/logger/comet.yaml b/configs/logger/comet.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..e0789274e2137ee6c97ca37a5d56c2b8abaf0aaa
--- /dev/null
+++ b/configs/logger/comet.yaml
@@ -0,0 +1,12 @@
+# https://www.comet.ml
+
+comet:
+ _target_: lightning.pytorch.loggers.comet.CometLogger
+ api_key: ${oc.env:COMET_API_TOKEN} # api key is loaded from environment variable
+ save_dir: "${paths.output_dir}"
+ project_name: "lightning-hydra-template"
+ rest_api_key: null
+ # experiment_name: ""
+ experiment_key: null # set to resume experiment
+ offline: False
+ prefix: ""
diff --git a/configs/logger/csv.yaml b/configs/logger/csv.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..fa028e9c146430c319101ffdfce466514338591c
--- /dev/null
+++ b/configs/logger/csv.yaml
@@ -0,0 +1,7 @@
+# CSV logger built into Lightning
+
+csv:
+ _target_: lightning.pytorch.loggers.csv_logs.CSVLogger
+ save_dir: "${paths.output_dir}"
+ name: "csv/"
+ prefix: ""
diff --git a/configs/logger/many_loggers.yaml b/configs/logger/many_loggers.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..dd586800bdccb4e8f4b0236a181b7ddd756ba9ab
--- /dev/null
+++ b/configs/logger/many_loggers.yaml
@@ -0,0 +1,9 @@
+# train with many loggers at once
+
+defaults:
+ # - comet
+ - csv
+ # - mlflow
+ # - neptune
+ - tensorboard
+ - wandb
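+
+# usage sketch (assumes the template's `python src/train.py` entry point):
+#   python src/train.py logger=many_loggers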
diff --git a/configs/logger/mlflow.yaml b/configs/logger/mlflow.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..f8fb7e685fa27fc8141387a421b90a0b9b492d9e
--- /dev/null
+++ b/configs/logger/mlflow.yaml
@@ -0,0 +1,12 @@
+# https://mlflow.org
+
+mlflow:
+ _target_: lightning.pytorch.loggers.mlflow.MLFlowLogger
+ # experiment_name: ""
+ # run_name: ""
+ tracking_uri: ${paths.log_dir}/mlflow/mlruns # run `mlflow ui` command inside the `logs/mlflow/` dir to open the UI
+ tags: null
+ # save_dir: "./mlruns"
+ prefix: ""
+ artifact_location: null
+ # run_id: ""
diff --git a/configs/logger/neptune.yaml b/configs/logger/neptune.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..8233c140018ecce6ab62971beed269991d31c89b
--- /dev/null
+++ b/configs/logger/neptune.yaml
@@ -0,0 +1,9 @@
+# https://neptune.ai
+
+neptune:
+ _target_: lightning.pytorch.loggers.neptune.NeptuneLogger
+ api_key: ${oc.env:NEPTUNE_API_TOKEN} # api key is loaded from environment variable
+ project: username/lightning-hydra-template
+ # name: ""
+ log_model_checkpoints: True
+ prefix: ""
diff --git a/configs/logger/tensorboard.yaml b/configs/logger/tensorboard.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..2bd31f6d8ba68d1f5c36a804885d5b9f9c1a9302
--- /dev/null
+++ b/configs/logger/tensorboard.yaml
@@ -0,0 +1,10 @@
+# https://www.tensorflow.org/tensorboard/
+
+tensorboard:
+ _target_: lightning.pytorch.loggers.tensorboard.TensorBoardLogger
+ save_dir: "${paths.output_dir}/tensorboard/"
+ name: null
+ log_graph: False
+ default_hp_metric: True
+ prefix: ""
+ # version: ""
diff --git a/configs/logger/wandb.yaml b/configs/logger/wandb.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..ece165889b3d0d9dc750a8f3c7454188cfdf12b7
--- /dev/null
+++ b/configs/logger/wandb.yaml
@@ -0,0 +1,16 @@
+# https://wandb.ai
+
+wandb:
+ _target_: lightning.pytorch.loggers.wandb.WandbLogger
+ # name: "" # name of the run (normally generated by wandb)
+ save_dir: "${paths.output_dir}"
+ offline: False
+ id: null # pass correct id to resume experiment!
+ anonymous: null # enable anonymous logging
+ project: "lightning-hydra-template"
+ log_model: False # upload lightning ckpts
+ prefix: "" # a string to put at the beginning of metric keys
+ # entity: "" # set to name of your wandb team
+ group: ""
+ tags: []
+ job_type: ""
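+
+# usage sketch for resuming a run by id (assumes the template's `python src/train.py` entry
+# point; <previous_run_id> is a placeholder):
+#   python src/train.py logger=wandb logger.wandb.id=<previous_run_id>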
diff --git a/configs/model/cdnetv1/README.md b/configs/model/cdnetv1/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..05bf8cdddde31a57408ac919b3de6fdd09bdd832
--- /dev/null
+++ b/configs/model/cdnetv1/README.md
@@ -0,0 +1,117 @@
+# CDnet: CNN-Based Cloud Detection for Remote Sensing Imagery
+
+> [CDnet: CNN-Based Cloud Detection for Remote Sensing Imagery](https://ieeexplore.ieee.org/document/8681238)
+
+## Introduction
+
+
+
+Official Repo
+
+Code Snippet
+
+## Abstract
+
+
+
+Cloud detection is one of the important tasks for remote sensing image (RSI) preprocessing. In this paper, we utilize the thumbnail (i.e., preview image) of RSI, which contains the information of original multispectral or panchromatic imagery, to extract cloud mask efficiently. Compared with detection cloud mask from original RSI, it is more challenging to detect cloud mask using thumbnails due to the loss of resolution and spectrum information. To tackle this problem, we propose a cloud detection neural network (CDnet) with an encoder–decoder structure, a feature pyramid module (FPM), and a boundary refinement (BR) block. The FPM extracts the multiscale contextual information without the loss of resolution and coverage; the BR block refines object boundaries; and the encoder–decoder structure gradually recovers segmentation results with the same size as input image. Experimental results on the ZY-3 satellite thumbnails cloud cover validation data set and two other validation data sets (GF-1 WFV Cloud and Cloud Shadow Cover Validation Data and Landsat-8 Cloud Cover Assessment Validation Data) demonstrate that the proposed method achieves accurate detection accuracy and outperforms several state-of-the-art methods.
+
+
+
+
+
+
+
+
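+A minimal PyTorch sketch of a residual-style boundary refinement (BR) block as described above is given below; the layer layout and channel handling are illustrative assumptions, not the authors' implementation.
+
+```python
+import torch
+import torch.nn as nn
+
+
+class BoundaryRefinement(nn.Module):
+    """Residual-style boundary refinement: a small 3x3 conv branch is added back
+    onto the input score map, x + f(x), sharpening object boundaries without
+    changing the spatial resolution."""
+
+    def __init__(self, channels: int):
+        super().__init__()
+        self.refine = nn.Sequential(
+            nn.Conv2d(channels, channels, kernel_size=3, padding=1),
+            nn.ReLU(inplace=True),
+            nn.Conv2d(channels, channels, kernel_size=3, padding=1),
+        )
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        return x + self.refine(x)
+
+
+if __name__ == "__main__":
+    scores = torch.randn(1, 2, 64, 64)  # 2-class cloud/background score map
+    print(BoundaryRefinement(2)(scores).shape)  # torch.Size([1, 2, 64, 64])
+```
+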
+## Results and models
+
+### CLOUD EXTRACTION ACCURACY (%)
+
+
+| Method | OA | MIoU | Kappa | PA | UA |
+|----------|-------|-------|-------|-------|-------|
+| CDnet(ASPP+GAP) | 95.41 | 89.38 | 82.05 | 87.82 | 89.85 |
+| CDnet(FPM) | 96.47 | 91.70 | 85.06 | 89.75 | 90.41 |
+
+
+### CLOUD EXTRACTION ACCURACY (%) FOR MODULES AND VARIANTS OF THE CDNET
+
+
+
+| Method | OA | MIoU | Kappa | PA | UA |
+|----------|-------|-------|-------|-------|-------|
+| ResNet50 | 91.13 | 82.83 | 73.38 | 81.99 | 80.34 |
+| MRN* | 93.03 | 85.24 | 77.51 | 82.59 | 82.82 |
+| MRN+FPM | 93.89 | 88.50 | 81.82 | 87.10 | 85.51 |
+| MRN+FPM+BR| 94.31 | 88.97 | 82.59 | 87.12 | 87.04 |
+| CDnet-FPM | 93.14 | 88.14 | 80.44 | 87.64 | 84.46 |
+| CDnet-BR | 95.04 | 89.63 | 83.78 | 87.36 | 88.67 |
+| CDnet-FPM-BR| 93.10 | 87.91 | 80.01 | 87.01 | 83.84 |
+| CDnet-A | 94.84 | 89.41 | 82.91 | 87.32 | 88.07 |
+| CDnet-B | 95.27 | 90.51 | 84.01 | 88.97 | 89.71 |
+| CDnet-C | 96.09 | 90.73 | 84.27 | 88.74 | 90.28 |
+| CDnet | 96.47 | 91.70 | 85.06 | 89.75 | 90.41 |
+
+MRN stands for the modified ResNet-50.
+
+
+
+### CLOUD EXTRACTION ACCURACY (%)
+
+
+| Method | OA | MIoU | Kappa | PA | UA |
+|----------|-------|-------|-------|-------|-------|
+| Maxlike | 77.73 | 66.16 | 53.55 | 91.30 | 54.98 |
+| SVM | 78.21 | 66.79 | 54.87 | 91.77 | 56.37 |
+| L-unet | 86.51 | 73.67 | 63.79 | 83.15 | 64.79 |
+| FCN-8 | 90.53 | 81.08 | 68.08 | 82.91 | 78.87 |
+| MVGG-16 | 92.73 | 86.65 | 78.94 | 88.12 | 81.84 |
+| DPN | 93.11 | 86.73 | 79.05 | 87.68 | 83.96 |
+| DeeplabV2 | 93.36 | 87.56 | 79.12 | 87.50 | 84.65 |
+| PSPnet | 94.24 | 88.37 | 81.41 | 86.67 | 89.17 |
+| DeeplabV3 | 95.03 | 88.74 | 81.53 | 87.63 | 89.72 |
+| DeeplabV3+| 96.01 | 90.45 | 83.92 | 88.47 | 90.03 |
+| CDnet | 96.47 | 91.70 | 85.06 | 89.75 | 90.41 |
+
+
+### Cloud Extraction Accuracy (%) of GF-1 Satellite Imagery
+
+| Method | OA | MIoU | Kappa | PA | UA |
+|----------|-------|-------|-------|-------|-------|
+| MFC | 92.36 | 80.32 | 74.64 | 83.58 | 75.32 |
+| L-unet | 92.44 | 82.39 | 76.26 | 87.61 | 74.98 |
+| FCN-8 | 92.61 | 82.71 | 76.45 | 87.45 | 75.61 |
+| MVGG-16 | 93.07 | 86.17 | 77.13 | 87.68 | 79.50 |
+| DPN | 93.19 | 86.32 | 77.25 | 86.85 | 80.93 |
+| DeeplabV2 | 95.07 | 87.00 | 80.07 | 86.60 | 82.18 |
+| PSPnet | 95.30 | 87.45 | 80.74 | 85.87 | 83.27 |
+| DeeplabV3 | 95.95 | 88.13 | 81.05 | 86.36 | 88.72 |
+| DeeplabV3+| 96.18 | 89.11 | 82.31 | 87.37 | 89.05 |
+| CDnet | 96.73 | 89.83 | 83.23 | 87.94 | 89.60 |
+
+### Cloud Extraction Accuracy (%) of Landsat-8 Satellite Imagery
+
+| Method | OA | MIoU | Kappa | PA | UA |
+|----------|-------|-------|-------|-------|-------|
+| Fmask | 85.21 | 71.52 | 63.01 | 86.24 | 70.38 |
+| L-unet | 90.56 | 77.95 | 68.79 | 79.32 | 78.94 |
+| FCN-8 | 90.88 | 78.84 | 71.32 | 76.28 | 82.31 |
+| MVGG-16 | 93.31 | 81.59 | 77.08 | 77.29 | 83.00 |
+| DPN | 93.40 | 86.34 | 81.52 | 84.61 | 89.93 |
+| DeeplabV2 | 94.11 | 86.90 | 81.63 | 84.93 | 89.87 |
+| PSPnet | 95.43 | 88.29 | 83.12 | 86.98 | 90.59 |
+| DeeplabV3 | 96.38 | 90.32 | 84.31 | 89.52 | 91.92 |
+| CDnet | 97.16 | 90.84 | 84.91 | 90.15 | 92.08 |
+
+## Citation
+
+```bibtex
+@ARTICLE{8681238,
+author={J. {Yang} and J. {Guo} and H. {Yue} and Z. {Liu} and H. {Hu} and K. {Li}},
+journal={IEEE Transactions on Geoscience and Remote Sensing},
+title={CDnet: CNN-Based Cloud Detection for Remote Sensing Imagery},
+year={2019},
+volume={57},
+number={8},
+pages={6195-6211}, doi={10.1109/TGRS.2019.2904868} }
+```
diff --git a/configs/model/cdnetv1/cdnetv1.yaml b/configs/model/cdnetv1/cdnetv1.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..f6d6b0e08e806029f3e50d0fb4fe3cc6915814c5
--- /dev/null
+++ b/configs/model/cdnetv1/cdnetv1.yaml
@@ -0,0 +1,19 @@
+_target_: src.models.base_module.BaseLitModule
+num_classes: 2
+
+net:
+ _target_: src.models.components.cdnetv1.CDnetV1
+ num_classes: 2
+
+criterion:
+ _target_: torch.nn.CrossEntropyLoss
+
+optimizer:
+ _target_: torch.optim.SGD
+ _partial_: true
+ lr: 0.0001
+
+scheduler: null
+
+# compile model for faster training with pytorch 2.0
+compile: false
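+
+# note on `_partial_: true` above: hydra.utils.instantiate() then returns a functools.partial
+# rather than a built optimizer, so the LightningModule can finish the call later with the model
+# parameters, e.g. (a sketch of what the module is expected to do, not a guaranteed API):
+#   optimizer = self.hparams.optimizer(params=self.parameters())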
diff --git a/configs/model/cdnetv2/README.md b/configs/model/cdnetv2/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..45b3b2fc058643144a9f420a9a22157e763fa48d
--- /dev/null
+++ b/configs/model/cdnetv2/README.md
@@ -0,0 +1,90 @@
+# CDnetV2: CNN-Based Cloud Detection for Remote Sensing Imagery With Cloud-Snow Coexistence
+
+> [CDnetV2: CNN-Based Cloud Detection for Remote Sensing Imagery With Cloud-Snow Coexistence](https://ieeexplore.ieee.org/document/9094671)
+
+## Introduction
+
+
+
+Official Repo
+
+Code Snippet
+
+## Abstract
+
+
+
+Cloud detection is a crucial preprocessing step for optical satellite remote sensing (RS) images. This article focuses on the cloud detection for RS imagery with cloud-snow coexistence and the utilization of the satellite thumbnails that lose considerable amount of high resolution and spectrum information of original RS images to extract cloud mask efficiently. To tackle this problem, we propose a novel cloud detection neural network with an encoder-decoder structure, named CDnetV2, as a series work on cloud detection. Compared with our previous CDnetV1, CDnetV2 contains two novel modules, that is, adaptive feature fusing model (AFFM) and high-level semantic information guidance flows (HSIGFs). AFFM is used to fuse multilevel feature maps by three submodules: channel attention fusion model (CAFM), spatial attention fusion model (SAFM), and channel attention refinement model (CARM). HSIGFs are designed to make feature layers at decoder of CDnetV2 be aware of the locations of the cloud objects. The high-level semantic information of HSIGFs is extracted by a proposed high-level feature fusing model (HFFM). By being equipped with these two proposed key modules, AFFM and HSIGFs, CDnetV2 is able to fully utilize features extracted from encoder layers and yield accurate cloud detection results. Experimental results on the ZY-3 satellite thumbnail data set demonstrate that the proposed CDnetV2 achieves accurate detection accuracy and outperforms several state-of-the-art methods.
+
+
+
+
+
+
+
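+The adaptive feature fusing idea can be illustrated with a generic channel-attention fusion of two same-sized feature maps; the sketch below is a squeeze-and-excitation style stand-in (channel counts and layout are assumptions), not the published CAFM/SAFM/CARM modules.
+
+```python
+import torch
+import torch.nn as nn
+
+
+class ChannelAttentionFusion(nn.Module):
+    """Fuse a low-level and a high-level feature map of the same shape: a
+    squeeze-and-excitation style gate predicts per-channel weights that decide
+    how the two maps are mixed."""
+
+    def __init__(self, channels: int, reduction: int = 8):
+        super().__init__()
+        self.gate = nn.Sequential(
+            nn.AdaptiveAvgPool2d(1),
+            nn.Conv2d(2 * channels, channels // reduction, kernel_size=1),
+            nn.ReLU(inplace=True),
+            nn.Conv2d(channels // reduction, channels, kernel_size=1),
+            nn.Sigmoid(),
+        )
+
+    def forward(self, low: torch.Tensor, high: torch.Tensor) -> torch.Tensor:
+        w = self.gate(torch.cat([low, high], dim=1))  # (N, C, 1, 1) channel weights
+        return w * low + (1.0 - w) * high
+
+
+if __name__ == "__main__":
+    low, high = torch.randn(1, 64, 32, 32), torch.randn(1, 64, 32, 32)
+    print(ChannelAttentionFusion(64)(low, high).shape)  # torch.Size([1, 64, 32, 32])
+```
+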
+## Results and models
+
+### CLOUD EXTRACTION ACCURACY (%) OF DIFFERENT CNN-BASED METHODS ON ZY-3 SATELLITE THUMBNAILS
+
+
+| Method | OA | MIoU | Kappa | PA | UA |
+|-------|-------|-------|-------|-------|-------|
+| MSegNet | 90.86 | 81.20 | 75.57 | 73.78 | 86.13 |
+| MUnet | 91.62 | 82.51 | 76.70 | 74.44 | 87.39 |
+| PSPnet | 90.58 | 81.63 | 75.36 | 76.02 | 87.52 |
+| DeeplabV3+ | 91.80 | 82.62 | 77.65 | 75.30 | 87.76 |
+| CDnetV1 | 93.15 | 82.80 | 79.21 | 82.37 | 86.72 |
+| CDnetV2 | 95.76 | 86.62 | 82.51 | 87.75 | 88.58 |
+
+
+
+### STATISTICAL RESULTS OF CLOUDAGE ESTIMATION ERROR IN TERMS OF THE MAD AND ITS VARIANCE
+
+
+| Methods | Mean value ($\mu$) | Variance ($\sigma^2$) |
+|------------|--------------------|----------------------------------|
+| CDnetV2 | 0.0241 | 0.0220 |
+| CDnetV1 | 0.0357 | 0.0288 |
+| DeeplabV3+ | 0.0456 | 0.0301 |
+| PSPnet | 0.0487 | 0.0380 |
+| MUnet | 0.0544 | 0.0583 |
+| MSegNet | 0.0572 | 0.0591 |
+
+
+
+
+### COMPUTATIONAL COMPLEXITY ANALYSIS OF DIFFERENT CNN-BASED METHODS
+
+| Methods | GFLOPs(224×224) | Trainable params | Running time (s)(1k×1k) |
+|------------|-----------------|------------------|-------------------------|
+| CDnetV2 | 31.5 | 65.9 M | 1.31 |
+| CDnetV1 | 48.5 | 64.8 M | 1.26 |
+| DeeplabV3+ | 31.8 | 40.3 M | 1.14 |
+| PSPnet | 19.3 | 46.6 M | 1.05 |
+| MUnet | 25.2 | 8.6 M | 1.09 |
+| MSegNet | 90.2 | 29.7 M | 1.28 |
+
+
+
+
+## Citation
+
+```bibtex
+@ARTICLE{8681238,
+author={J. {Yang} and J. {Guo} and H. {Yue} and Z. {Liu} and H. {Hu} and K. {Li}},
+journal={IEEE Transactions on Geoscience and Remote Sensing},
+title={CDnet: CNN-Based Cloud Detection for Remote Sensing Imagery},
+year={2019}, volume={57},
+number={8}, pages={6195-6211},
+doi={10.1109/TGRS.2019.2904868} }
+
+@ARTICLE{9094671,
+author={J. {Guo} and J. {Yang} and H. {Yue} and H. {Tan} and C. {Hou} and K. {Li}},
+journal={IEEE Transactions on Geoscience and Remote Sensing},
+title={CDnetV2: CNN-Based Cloud Detection for Remote Sensing Imagery With Cloud-Snow Coexistence},
+year={2021},
+volume={59},
+number={1},
+pages={700-713},
+doi={10.1109/TGRS.2020.2991398} }
+```
diff --git a/configs/model/cdnetv2/cdnetv2.yaml b/configs/model/cdnetv2/cdnetv2.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..24fa0179f836ac8419181a78527a27ccf626c4d0
--- /dev/null
+++ b/configs/model/cdnetv2/cdnetv2.yaml
@@ -0,0 +1,19 @@
+_target_: src.models.cdnetv2_module.CDNetv2LitModule
+
+net:
+ _target_: src.models.components.cdnetv2.CDnetV2
+ num_classes: 2
+
+num_classes: 2
+
+criterion:
+ _target_: src.loss.cdnetv2_loss.CDnetv2Loss
+ loss_fn:
+ _target_: torch.nn.CrossEntropyLoss
+
+optimizer:
+ _target_: torch.optim.SGD
+ _partial_: true
+ lr: 0.0001
+
+scheduler: null
\ No newline at end of file
diff --git a/configs/model/cnn.yaml b/configs/model/cnn.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..add5938cc4c4740056b8dbbf6edf32415dcf4ccc
--- /dev/null
+++ b/configs/model/cnn.yaml
@@ -0,0 +1,21 @@
+_target_: src.models.mnist_module.MNISTLitModule
+
+optimizer:
+ _target_: torch.optim.Adam
+ _partial_: true
+ lr: 0.001
+ weight_decay: 0.0
+
+scheduler:
+ _target_: torch.optim.lr_scheduler.ReduceLROnPlateau
+ _partial_: true
+ mode: min
+ factor: 0.1
+ patience: 10
+
+net:
+ _target_: src.models.components.cnn.CNN
+ dim: 32
+
+# compile model for faster training with pytorch 2.0
+compile: false
diff --git a/configs/model/dbnet/README.md b/configs/model/dbnet/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..3a7dc71d248fd114b0801957965f2f08c86243b4
--- /dev/null
+++ b/configs/model/dbnet/README.md
@@ -0,0 +1,107 @@
+# Dual-Branch Network for Cloud and Cloud Shadow Segmentation
+
+
+> [Dual-Branch Network for Cloud and Cloud Shadow Segmentation](https://ieeexplore.ieee.org/document/9775689)
+
+## Introduction
+
+
+
+Official Repo
+
+Code Snippet
+
+## Abstract
+
+
+
+Cloud and cloud shadow segmentation is one of the most important issues in remote sensing image processing. Most of the remote sensing images are very complicated. In this work, a dual-branch model composed of transformer and convolution network is proposed to extract semantic and spatial detail information of the image, respectively, to solve the problems of false detection and missed detection. To improve the model's feature extraction, a mutual guidance module (MGM) is introduced, so that the transformer branch and the convolution branch can guide each other for feature mining. Finally, in view of the problem of rough segmentation boundary, this work uses different features extracted by the transformer branch and the convolution branch for decoding and repairs the rough segmentation boundary in the decoding part to make the segmentation boundary clearer. Experimental results on the Landsat-8, Sentinel-2 data, the public dataset high-resolution cloud cover validation dataset created by researchers at Wuhan University (HRC_WHU), and the public dataset Spatial Procedures for Automated Removal of Cloud and Shadow (SPARCS) demonstrate the effectiveness of our method and its superiority to the existing state-of-the-art cloud and cloud shadow segmentation approaches.
+
+
+
+
+
+
+
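+The mutual guidance idea can be sketched as each branch re-weighting the other branch's features; the toy module below uses simple 1x1 convolution gates and is only an illustration under those assumptions, not the paper's MGM.
+
+```python
+import torch
+import torch.nn as nn
+
+
+class MutualGuidance(nn.Module):
+    """Toy mutual-guidance step: the semantic (transformer) features gate the
+    spatial-detail (convolution) features and vice versa, so the two branches
+    can guide each other's feature mining."""
+
+    def __init__(self, channels: int):
+        super().__init__()
+        self.sem_gate = nn.Sequential(nn.Conv2d(channels, channels, 1), nn.Sigmoid())
+        self.det_gate = nn.Sequential(nn.Conv2d(channels, channels, 1), nn.Sigmoid())
+
+    def forward(self, semantic: torch.Tensor, detail: torch.Tensor):
+        guided_detail = detail * self.sem_gate(semantic)    # semantics guide details
+        guided_semantic = semantic * self.det_gate(detail)  # details guide semantics
+        return guided_semantic, guided_detail
+
+
+if __name__ == "__main__":
+    s, d = torch.randn(1, 64, 32, 32), torch.randn(1, 64, 32, 32)
+    gs, gd = MutualGuidance(64)(s, d)
+    print(gs.shape, gd.shape)
+```
+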
+## Results and models
+
+### COMPARISON OF EVALUATION METRICS OF DIFFERENT MODELS ON CLOUD AND CLOUD SHADOW DATASET
+
+
+| **Method** | **Cloud** | | | | **Cloud Shadow** | | | | **Overall** | | |
+|--------------------|-----------|----------|----------|-----------|------------------|----------|----------|-----------|-----------|------------|-------------|
+| | **OA(%)** | **P(%)** | **R(%)** | **F₁(%)** | **OA(%)** | **P(%)** | **R(%)** | **F₁(%)** | **PA(%)** | **MPA(%)** | **MIoU(%)** |
+| FCN-8S [7] | 95.87 | 88.47 | 92.8 | 90.63 | 97.19 | 86.87 | 88.72 | 87.79 | 93.40 | 90.52 | 84.01 |
+| PAN [38] | 98.25 | 96.29 | 95.49 | 95.89 | 98.31 | 92.71 | 92.46 | 92.59 | 96.73 | 95.52 | 91.26 |
+| BiseNet V2 [35] | 98.27 | 96.57 | 95.28 | 95.92 | 98.34 | 94.18 | 90.99 | 92.59 | 96.68 | 95.96 | 91.20 |
+| PSPNet [11] | 98.35 | 96.13 | 96.13 | 96.13 | 98.40 | 92.70 | 93.27 | 92.99 | 96.87 | 95.55 | 91.69 |
+| DeepLab V3Plus [9] | 98.65 | 97.70 | 95.94 | 96.82 | 98.66 | 93.88 | 94.32 | 94.10 | 97.37 | 96.48 | 92.99 |
+| LinkNet [39] | 98.61 | 96.59 | 96.91 | 96.75 | 98.54 | 94.19 | 92.91 | 93.55 | 97.23 | 96.24 | 92.55 |
+| ExtremeC3Net [40] | 98.64 | 97.32 | 96.28 | 96.80 | 98.60 | 94.68 | 92.95 | 93.82 | 97.30 | 96.57 | 92.76 |
+| DANet [41] | 96.45 | 91.68 | 91.72 | 91.71 | 97.29 | 88.40 | 87.68 | 88.04 | 94.03 | 91.93 | 85.07 |
+| CGNet [42] | 98.37 | 95.93 | 96.48 | 96.20 | 98.27 | 93.33 | 91.34 | 92.34 | 96.73 | 95.60 | 91.27 |
+| PVT [23] | 98.57 | 97.45 | 95.84 | 96.65 | 98.55 | 93.28 | 94.08 | 93.68 | 97.21 | 96.18 | 92.55 |
+| CvT [24] | 98.44 | 95.89 | 96.88 | 96.38 | 98.32 | 92.90 | 92.24 | 92.57 | 96.85 | 95.54 | 91.57 |
+| modified VGG [12] | 98.40 | 98.13 | 94.30 | 96.22 | 98.57 | 94.41 | 92.88 | 93.64 | 97.04 | 96.56 | 92.17 |
+| CloudNet [13] | 98.70 | 97.22 | 96.68 | 96.95 | 98.40 | 92.05 | 94.05 | 93.05 | 97.17 | 95.77 | 92.36 |
+| GAFFNet [43] | 98.53 | 96.49 | 96.63 | 96.56 | 98.41 | 92.71 | 93.40 | 93.05 | 97.06 | 95.73 | 92.08 |
+| Our | 98.76 | 97.95 | 96.22 | 97.08 | 98.73 | 94.39 | 94.39 | 94.39 | 97.56 | 96.77 | 93.42 |
+
+### COMPARISON OF EVALUATION METRICS OF DIFFERENT MODELS ON THE SPARCS DATASET
+
+
+| Method | Class Pixel Accuracy | | | | | Overall Results | | | | |
+|---|---|---|---|---|---|---|---|---|---|---|
+| | Cloud(%) | Cloud Shadow(%) | Snow/Ice(%) | Water(%) | Land(%) | PA(%) | Recall(%) | Precision(%) | F₁(%) | MIoU(%) |
+| PAN [38] | 89.10 | 75.27 | 86.60 | 79.96 | 95.64 | 91.20 | 87.34 | 85.32 | 81.53 | 76.57 |
+| BiSeNet V2 [35] | 85.87 | 64.75 | 93.84 | 81.44 | 97.17 | 91.31 | 89.77 | 84.61 | 83.09 | 77.79 |
+| PSPNet [11] | 90.79 | 63.75 | 94.22 | 77.73 | 96.84 | 91.78 | 90.29 | 84.67 | 83.48 | 78.20 |
+| DeepLab V3Plus [9] | 87.81 | 72.12 | 85.17 | 81.27 | 97.84 | 91.99 | 90.75 | 84.85 | 84.01 | 78.44 |
+| LinkNet [39] | 85.35 | 74.38 | 91.92 | 80.30 | 96.44 | 91.31 | 88.66 | 85.68 | 82.81 | 77.87 |
+| ExtremeC3Net [40] | 91.09 | 75.47 | 95.43 | 83.62 | 96.13 | 92.77 | 90.32 | 88.35 | 85.46 | 81.29 |
+| DANet [41] | 82.06 | 42.25 | 91.28 | 73.65 | 95.03 | 86.92 | 83.86 | 76.85 | 74.64 | 68.33 |
+| CGNet [42] | 90.63 | 72.78 | 95.37 | 83.30 | 96.51 | 93.22 | 91.00 | 88.95 | 86.30 | 82.28 |
+| PVT [23] | 88.22 | 75.77 | 92.00 | 86.27 | 95.92 | 92.02 | 89.76 | 87.64 | 84.66 | 80.24 |
+| CvT [24] | 88.24 | 71.63 | 95.41 | 87.71 | 96.14 | 92.17 | 89.83 | 87.83 | 84.80 | 80.55 |
+| modified VGG [12] | 85.55 | 58.53 | 94.87 | 79.35 | 95.98 | 89.99 | 86.38 | 82.85 | 79.36 | 74.00 |
+| CloudNet [13] | 85.99 | 74.58 | 91.78 | 80.34 | 96.52 | 91.50 | 88.49 | 85.84 | 82.79 | 77.95 |
+| GAFFNet [43] | 86.97 | 59.00 | 85.21 | 78.06 | 94.47 | 88.62 | 86.70 | 80.74 | 78.56 | 72.32 |
+| our | 91.12 | 78.38 | 96.59 | 89.99 | 97.52 | 94.31 | 92.90 | 90.72 | 88.83 | 85.26 |
+
+
+
+
+
+
+## Citation
+
+```bibtex
+@ARTICLE{9775689,
+ author={Lu, Chen and Xia, Min and Qian, Ming and Chen, Binyu},
+ journal={IEEE Transactions on Geoscience and Remote Sensing},
+ title={Dual-Branch Network for Cloud and Cloud Shadow Segmentation},
+ year={2022},
+ volume={60},
+ number={},
+ pages={1-12},
+ keywords={Feature extraction;Transformers;Convolution;Clouds;Image segmentation;Decoding;Task analysis;Deep learning;dual branch;remote sensing image;segmentation},
+ doi={10.1109/TGRS.2022.3175613}}
+
+```
diff --git a/configs/model/dbnet/dbnet.yaml b/configs/model/dbnet/dbnet.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..f91ea3aea3a70e46dd69c02582312258164f33cd
--- /dev/null
+++ b/configs/model/dbnet/dbnet.yaml
@@ -0,0 +1,23 @@
+_target_: src.models.base_module.BaseLitModule
+
+net:
+ _target_: src.models.components.dbnet.DBNet
+ img_size: 256
+ in_channels: 3
+ num_classes: 2
+
+num_classes: 2
+
+criterion:
+ _target_: torch.nn.CrossEntropyLoss
+
+optimizer:
+ _target_: torch.optim.Adam
+ _partial_: true
+ weight_decay: 0.0001
+ lr: 0.0001
+
+scheduler:
+ _target_: torch.optim.lr_scheduler.StepLR
+ _partial_: true
+ step_size: 3
\ No newline at end of file
diff --git a/configs/model/hrcloudnet/README.md b/configs/model/hrcloudnet/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..5fbadccde7b98620e39ea3afdf7b25e3e05eabdb
--- /dev/null
+++ b/configs/model/hrcloudnet/README.md
@@ -0,0 +1,102 @@
+# High-Resolution Cloud Detection Network
+
+> [High-Resolution Cloud Detection Network](https://arxiv.org/abs/2407.07365)
+
+## Introduction
+
+
+
+Official Repo
+
+Code Snippet
+
+## Abstract
+
+
+
+The complexity of clouds, particularly in terms of texture detail at high resolutions, has not been well explored by most existing cloud detection networks. This paper introduces the High-Resolution Cloud Detection Network (HR-cloud-Net), which utilizes a hierarchical high-resolution integration approach. HR-cloud-Net integrates a high-resolution representation module, layer-wise cascaded feature fusion module, and multi-resolution pyramid pooling module to effectively capture complex cloud features. This architecture preserves detailed cloud texture information while facilitating feature exchange across different resolutions, thereby enhancing overall performance in cloud detection. Additionally, a novel approach is introduced wherein a student view, trained on noisy augmented images, is supervised by a teacher view processing normal images. This setup enables the student to learn from cleaner supervisions provided by the teacher, leading to improved performance. Extensive evaluations on three optical satellite image cloud detection datasets validate the superior performance of HR-cloud-Net compared to existing methods.
+
+
+
+
+
+
+
+
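+The student/teacher supervision described above can be sketched as a single training step in which the teacher view sees the clean image and the student view sees a noisy augmentation; the snippet below is an illustration with a stand-in network and an assumed loss weighting, not the authors' training code.
+
+```python
+import torch
+import torch.nn.functional as F
+
+
+def student_teacher_step(model, clean, noisy, target):
+    """The teacher view (clean image) is trained against the ground truth, while
+    the student view (noisy image) is pulled toward the detached teacher
+    prediction, i.e. the cleaner supervision."""
+    teacher_logits = model(clean)
+    student_logits = model(noisy)
+
+    supervised = F.cross_entropy(teacher_logits, target)
+    consistency = F.kl_div(
+        F.log_softmax(student_logits, dim=1),
+        F.softmax(teacher_logits.detach(), dim=1),  # no gradient through the teacher
+        reduction="batchmean",
+    )
+    return supervised + consistency
+
+
+if __name__ == "__main__":
+    net = torch.nn.Conv2d(3, 2, kernel_size=1)  # stand-in for the segmentation network
+    imgs = torch.rand(2, 3, 64, 64)
+    noisy = imgs + 0.1 * torch.randn_like(imgs)
+    labels = torch.randint(0, 2, (2, 64, 64))
+    print(student_teacher_step(net, imgs, noisy, labels).item())
+```
+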
+## Results and models
+
+### CHLandSat-8 dataset
+
+
+| method | mae | weight-F-measure | structure-measure |
+|--------------|------------|------------------|-------------------|
+| U-Net | 0.1130 | 0.7448 | 0.7228 |
+| PSPNet | 0.0969 | 0.7989 | 0.7672 |
+| SegNet | 0.1023 | 0.7780 | 0.7540 |
+| Cloud-Net | 0.1012 | 0.7641 | 0.7368 |
+| CDNet | 0.1286 | 0.7222 | 0.7087 |
+| CDNet-v2 | 0.1254 | 0.7350 | 0.7141 |
+| HRNet | **0.0737** | 0.8279 | **0.8141** |
+| GANet | 0.0751 | **0.8396** | 0.8106 |
+| HR-cloud-Net | **0.0628** | **0.8503** | **0.8337** |
+
+### 38-cloud dataset
+
+
+| method | mae | weight-F-measure | structure-measure |
+|--------------|------------|------------------|-------------------|
+| U-Net | 0.0638 | 0.7966 | 0.7845 |
+| PSPNet | 0.0653 | 0.7592 | 0.7766 |
+| SegNet | 0.0556 | 0.8002 | 0.8059 |
+| Cloud-Net | 0.0556 | 0.7615 | 0.7987 |
+| CDNet | 0.1057 | 0.7378 | 0.7270 |
+| CDNet-v2 | 0.1084 | 0.7183 | 0.7213 |
+| HRNet | 0.0538 | 0.8086 | 0.8183 |
+| GANet | **0.0410** | **0.8159** | **0.8342** |
+| HR-cloud-Net | **0.0395** | **0.8673** | **0.8479** |
+
+
+### SPARCS dataset
+
+
+| method | mae | weight-F-measure | structure-measure |
+|--------------|------------|------------------|-------------------|
+| U-Net | 0.1314 | 0.3651 | 0.5416 |
+| PSPNet | 0.1263 | 0.3758 | 0.5414 |
+| SegNet | 0.1100 | 0.4697 | 0.5918 |
+| Cloud-Net | 0.1213 | 0.3804 | 0.5536 |
+| CDNet | 0.1157 | 0.4585 | 0.5919 |
+| CDNet-v2 | 0.1219 | 0.4247 | 0.5704 |
+| HRNet | 0.1008 | 0.3742 | 0.5777 |
+| GANet | **0.0987** | **0.5134** | **0.6210** |
+| HR-cloud-Net | **0.0833** | **0.5202** | **0.6327** |
+
+
+
+
+## Citation
+
+```bibtex
+@article{LiJEI2024,
+  author = {Jingsheng Li and Tianxiang Xue and Jiayi Zhao and
+            Jingmin Ge and Yufang Min and Wei Su and Kun Zhan},
+  title = {High-Resolution Cloud Detection Network},
+  journal = {Journal of Electronic Imaging},
+ year = {2024},
+}
+```
\ No newline at end of file
diff --git a/configs/model/hrcloudnet/hrcloudnet.yaml b/configs/model/hrcloudnet/hrcloudnet.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..848b9833e9ddba1293effebd73ed1b3adf9faa04
--- /dev/null
+++ b/configs/model/hrcloudnet/hrcloudnet.yaml
@@ -0,0 +1,18 @@
+_target_: src.models.base_module.BaseLitModule
+
+net:
+ _target_: src.models.components.hrcloudnet.HRCloudNet
+ num_classes: 2
+
+num_classes: 2
+
+criterion:
+ _target_: torch.nn.CrossEntropyLoss
+
+optimizer:
+ _target_: torch.optim.Adam
+ _partial_: true
+ lr: 0.00005
+ weight_decay: 0.0005
+
+scheduler: null
\ No newline at end of file
diff --git a/configs/model/lnn.yaml b/configs/model/lnn.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..d02662f8cdd442ee987695d83b33fb2ed4178142
--- /dev/null
+++ b/configs/model/lnn.yaml
@@ -0,0 +1,21 @@
+_target_: src.models.mnist_module.MNISTLitModule
+
+optimizer:
+ _target_: torch.optim.Adam
+ _partial_: true
+ lr: 0.001
+ weight_decay: 0.0
+
+scheduler:
+ _target_: torch.optim.lr_scheduler.ReduceLROnPlateau
+ _partial_: true
+ mode: min
+ factor: 0.1
+ patience: 10
+
+net:
+ _target_: src.models.components.lnn.LNN
+ dim: 32
+
+# compile model for faster training with pytorch 2.0
+compile: false
diff --git a/configs/model/mcdnet/README.md b/configs/model/mcdnet/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..ea6597541e4afa90a2621807c61164b6bfc86b94
--- /dev/null
+++ b/configs/model/mcdnet/README.md
@@ -0,0 +1,115 @@
+# MCDNet: Multilevel cloud detection network for remote sensing images based on dual-perspective change-guided and multi-scale feature fusion
+
+> [MCDNet: Multilevel cloud detection network for remote sensing images based on dual-perspective change-guided and multi-scale feature fusion](https://www.sciencedirect.com/science/article/pii/S1569843224001742?via%3Dihub)
+
+## Introduction
+
+
+
+Official Repo
+
+Code Snippet
+
+## Abstract
+
+
+Cloud detection plays a crucial role in the preprocessing of optical remote sensing images. While extensive deep learning-based methods have shown strong performance in detecting thick clouds, their ability to identify thin and broken clouds is often inadequate due to their sparse distribution, semi-transparency, and similarity to background regions. To address this limitation, we introduce a multilevel cloud detection network (MCDNet) capable of simultaneously detecting thick and thin clouds. This network effectively enhances the accuracy of identifying thin and broken clouds by integrating a dual-perspective change-guided mechanism (DPCG) and a multi-scale feature fusion module (MSFF). The DPCG creates a dual-input stream by combining the original image with the thin cloud removal image, and then utilizes a dual-perspective feature fusion module (DPFF) to perform feature fusion and extract change features, thereby improving the model's ability to perceive thin cloud regions and mitigate inter-class similarity in multilevel cloud detection. The MSFF enhances the model's sensitivity to broken clouds by utilizing multiple non-adjacent low-level features to remedy the missing spatial information in the high-level features during multiple downsampling. Experimental results on the L8-Biome and WHUS2-CD datasets demonstrate that MCDNet significantly enhances the detection performance of both thin and broken clouds, and outperforms state-of-the-art methods in accuracy and efficiency.
+
+
+
+
+
+
+
+
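+The dual-perspective change-guided idea can be sketched as encoding the original image and its thin-cloud-removal counterpart with a shared encoder and fusing their difference; the module below is an illustration (channel widths and the fusion layout are assumptions), not the published DPFF.
+
+```python
+import torch
+import torch.nn as nn
+
+
+class DualPerspectiveFusion(nn.Module):
+    """Encode the original image and its thin-cloud-removal counterpart with a
+    shared encoder, treat their absolute difference as a change cue, and fuse
+    everything with a 1x1 convolution."""
+
+    def __init__(self, in_channels: int = 3, feat_channels: int = 32):
+        super().__init__()
+        self.encoder = nn.Sequential(
+            nn.Conv2d(in_channels, feat_channels, 3, padding=1), nn.ReLU(inplace=True)
+        )
+        self.fuse = nn.Conv2d(3 * feat_channels, feat_channels, kernel_size=1)
+
+    def forward(self, image: torch.Tensor, dehazed: torch.Tensor) -> torch.Tensor:
+        f_img, f_dehazed = self.encoder(image), self.encoder(dehazed)
+        change = torch.abs(f_img - f_dehazed)  # thin-cloud regions change the most
+        return self.fuse(torch.cat([f_img, f_dehazed, change], dim=1))
+
+
+if __name__ == "__main__":
+    x = torch.rand(1, 3, 128, 128)
+    print(DualPerspectiveFusion()(x, x * 0.9).shape)  # torch.Size([1, 32, 128, 128])
+```
+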
+## Results and models
+
+### Quantitative performance of MCDNet with different thin cloud removal methods on the L8-Biome dataset.
+
+
+| Thin cloud removal method | OA | IoU | Specificity | Kappa |
+|--------------------------|------|------|-------------|--------|
+| DCP | 94.18| 89.55| 97.09 | 70.42 |
+| HF | 94.22| 89.69| 97.11 | 70.59 |
+| BCCR | 94.25| 89.75| 97.13 | 71.01 |
+
+### Quantitative performance of MCDNet with different components using BCCR thin cloud removal method on the L8-Biome dataset
+
+
+| Baseline | √ | √ | √ | √ |
+|----------------|-------|-------|-------|-------|
+| MSFF | × | √ | × | √ |
+| DPFF | × | × | √ | √ |
+| OA | 92.91 | 93.25 | 93.81 | 94.25 |
+| IoU | 87.58 | 88.13 | 88.99 | 89.75 |
+| Specificity | 96.46 | 96.62 | 96.90 | 97.13 |
+| Kappa | 67.22 | 68.88 | 69.94 | 71.01 |
+| Parameters (M) | 8.89 | 10.67 | 11.34 | 13.11 |
+
+
+
+### Quantitative comparisons of different methods on the L8-Biome dataset.
+
+
+| Method | OA | IoU | Specificity | Kappa |
+|----------|-------|-------|-------------|-------|
+| UNet | 90.78 | 84.44 | 95.39 | 63.22 |
+| SegNet | 91.44 | 85.33 | 95.72 | 63.98 |
+| HRNet | 91.17 | 84.87 | 95.59 | 62.32 |
+| SwinUnet | 91.73 | 85.74 | 95.86 | 64.36 |
+| MFCNN | 92.39 | 86.76 | 96.20 | 66.72 |
+| MSCFF | 92.49 | 86.99 | 96.25 | 65.82 |
+| CDNetV2 | 91.55 | 85.51 | 95.77 | 64.91 |
+| CloudNet | 92.19 | 86.55 | 96.09 | 65.99 |
+| MCDNet | 94.25 | 89.75 | 97.13 | 71.01 |
+
+
+### Quantitative comparisons of different methods on the WHUS2-CD dataset.
+
+| Method | OA | IoU | Specificity | Kappa |
+|----------|-------|-------|-------------|-------|
+| UNet | 98.24 | 63.72 | 99.51 | 63.45 |
+| SegNet | 98.09 | 62.53 | 99.15 | 61.71 |
+| HRNet | 97.91 | 61.73 | 99.41 | 59.29 |
+| SwinUnet | 98.69 | 61.93 | 99.48 | 67.81 |
+| MFCNN | 98.56 | 64.79 | 98.92 | 65.52 |
+| MSCFF | 98.84 | 66.21 | 99.59 | 68.73 |
+| CDNetV2 | 98.68 | 65.69 | 99.54 | 67.21 |
+| CloudNet | 98.58 | 65.82 | 99.55 | 68.57 |
+| MCDNet | 98.97 | 66.45 | 99.58 | 69.42 |
+
+
+### Quantitative performance of different fusion schemes for extracting change features on L8-Biome dataset.
+
+| Methods | Fusion schemes | OA | IoU | Specificity | Kappa | Params (M) |
+|---------|----------------|-------|-------|-------------|-------|------------|
+| (a) | no fusion | 93.25 | 88.13 | 96.62 | 68.88 | 10.67 |
+| (b) | Concatenation | 94.14 | 89.62 | 97.07 | 66.49 | 11.36 |
+| (c) | Subtraction | 91.05 | 84.85 | 95.52 | 64.32 | 10.67 |
+| (d) | CDM | 93.77 | 89.02 | 96.88 | 65.46 | 12.07 |
+| (e) | DPFF | 94.25 | 89.75 | 97.13 | 71.01 | 13.11 |
+
+
+### Quantitative performance of different multi-scale feature fusion schemes on L8-Biome dataset.
+
+| Methods | Fusion schemes | OA | IoU | Specificity | Kappa | Params (M) |
+|-----------|----------------|------|------|-------------|--------|------------|
+| (a) | no fusion | 93.81| 88.99| 96.90 | 69.94 | 11.34 |
+| (b) | HRNet | 94.03| 89.47| 97.02 | 69.59 | 12.58 |
+| (c) | MSCFF | 93.45| 88.45| 96.72 | 68.87 | 14.86 |
+| (d) | CloudNet | 93.92| 89.29| 96.96 | 66.97 | 11.34 |
+| (e) | MSFF | 94.25| 89.75| 97.13 | 71.01 | 13.11 |
+
+## Citation
+
+```bibtex
+@article{MCDNet,
+title = {MCDNet: Multilevel cloud detection network for remote sensing images based on dual-perspective change-guided and multi-scale feature fusion},
+journal = {International Journal of Applied Earth Observation and Geoinformation},
+volume = {129},
+pages = {103820},
+year = {2024},
+issn = {1569-8432},
+doi = {10.1016/j.jag.2024.103820}
+}
+```
diff --git a/configs/model/mcdnet/mcdnet.yaml b/configs/model/mcdnet/mcdnet.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..fe0faa2092a56e0abea70b49cef73921c92d286f
--- /dev/null
+++ b/configs/model/mcdnet/mcdnet.yaml
@@ -0,0 +1,22 @@
+_target_: src.models.base_module.BaseLitModule
+
+net:
+ _target_: src.models.components.mcdnet.MCDNet
+ in_channels: 3
+ num_classes: 2
+
+num_classes: 2
+
+criterion:
+ _target_: torch.nn.CrossEntropyLoss
+
+optimizer:
+ _target_: torch.optim.SGD
+ _partial_: true
+ lr: 0.0001
+scheduler: null
+
+#scheduler:
+# _target_: torch.optim.lr_scheduler.LambdaLR
+# _partial_: true
+# lr_lambda: src.models.components.mcdnet.lr_lambda
\ No newline at end of file
diff --git a/configs/model/scnn/README.md b/configs/model/scnn/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..ef8f2f66e2b956f985ae132fb238be15f789d36b
--- /dev/null
+++ b/configs/model/scnn/README.md
@@ -0,0 +1,109 @@
+# Remote sensing image cloud detection using a shallow convolutional neural network
+
+> [Remote sensing image cloud detection using a shallow convolutional neural network](https://www.sciencedirect.com/science/article/abs/pii/S0924271624000352?via%3Dihub#fn1)
+
+## Introduction
+
+
+
+Official Repo
+
+Code Snippet
+
+## Abstract
+
+
+
+The state-of-the-art methods for cloud detection are dominated by deep convolutional neural networks (DCNNs). However, it is very expensive to train DCNNs for cloud detection and the trained DCNNs do not always perform well as expected. This paper proposes a shallow CNN (SCNN) by removing pooling/unpooling layers and normalization layers in DCNNs, retaining only three convolutional layers, and equipping them with 3 filters of 1 × 1, 1 × 1, 3 × 3 in spatial dimensions. It demonstrates that the three convolutional layers are sufficient for cloud detection. Since the label output by the SCNN for a pixel depends on a 3 × 3 patch around this pixel, the SCNN can be trained using some thousands of 3 × 3 patches together with ground truth of their center pixels. It is very cheap to train a SCNN using some thousands of 3 × 3 patches and to provide ground truth of their center pixels. Despite its low cost, SCNN training is stabler than DCNN training, and the trained SCNN outperforms the representative state-of-the-art DCNNs for cloud detection. The same resolution of original image, feature maps and final label map assures that details are not lost as by pooling/unpooling in DCNNs. The border artifacts suffering from deep convolutional and pooling/unpooling layers are minimized by 3 convolutional layers with 1 × 1, 1 × 1, 3 × 3 filters. Incoherent patches suffering from patch-by-patch segmentation and batch normalization are eliminated by SCNN without normalization layers. Extensive experiments based on the L7 Irish, L8 Biome and GF1 WHU datasets are carried out to evaluate the proposed method and compare with state-of-the-art methods. The proposed SCNN promises to deal with images from any other sensors.
+
+
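+The abstract pins the architecture down almost completely: three convolutional layers with 1 × 1, 1 × 1 and 3 × 3 filters, no pooling/unpooling and no normalization, so the output keeps the input resolution. The sketch below follows that recipe; the channel widths are illustrative guesses, not the published ones.
+
+```python
+import torch
+import torch.nn as nn
+
+
+class SCNNSketch(nn.Module):
+    """Three convolutional layers only (1x1, 1x1, 3x3): the receptive field of
+    each output pixel is a 3x3 patch and the label map keeps the input
+    resolution, since there is no pooling or normalization."""
+
+    def __init__(self, in_channels: int = 3, hidden: int = 64, num_classes: int = 2):
+        super().__init__()
+        self.net = nn.Sequential(
+            nn.Conv2d(in_channels, hidden, kernel_size=1),
+            nn.ReLU(inplace=True),
+            nn.Conv2d(hidden, hidden, kernel_size=1),
+            nn.ReLU(inplace=True),
+            nn.Conv2d(hidden, num_classes, kernel_size=3, padding=1),
+        )
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        return self.net(x)
+
+
+if __name__ == "__main__":
+    print(SCNNSketch()(torch.rand(1, 3, 256, 256)).shape)  # torch.Size([1, 2, 256, 256])
+```
+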
+## Results and models
+
+### Quantitative evaluation for SCNN on L7 Irish, L8 Biome and GF1 WHU datasets
+
+| Dataset | Zone/Biome | Overall | Clear | | | Cloud | | |
+|----------|------------|-------|-------|-------|-------|-------|-------|-------|
+| | | A^O | A^P | A^U | F_1 | A^P | A^U | F_1 |
+| L7 Irish | Austral | 88.87 | 86.94 | 76.29 | 81.27 | 89.61 | 94.69 | 92.08 |
+| | Boreal | 97.06 | 98.78 | 96.53 | 97.64 | 94.31 | 97.96 | 96.10 |
+| | Mid.N. | 93.90 | 97.04 | 89.02 | 92.86 | 91.74 | 97.82 | 94.68 |
+| | Mid.S. | 97.19 | 98.96 | 95.94 | 97.43 | 95.13 | 98.75 | 96.91 |
+| | Polar.N. | 83.55 | 81.99 | 84.55 | 83.25 | 85.11 | 82.61 | 83.84 |
+| | SubT.N. | 98.16 | 99.44 | 98.65 | 99.04 | 69.09 | 84.38 | 75.97 |
+| | SubT.S. | 95.35 | 99.02 | 95.10 | 97.02 | 83.51 | 96.34 | 89.47 |
+| | Tropical | 89.72 | 86.48 | 42.76 | 57.23 | 90.00 | 98.72 | 94.16 |
+| | Mean | 94.13 | 97.17 | 93.73 | 95.42 | 88.97 | 94.88 | 91.83 |
+| L8 Biome | Barren | 93.46 | 90.81 | 97.89 | 94.22 | 97.23 | 88.19 | 92.49 |
+| | Forest | 96.65 | 90.65 | 88.18 | 89.40 | 97.76 | 98.26 | 98.01 |
+| | Grass | 94.40 | 93.75 | 99.51 | 96.54 | 97.67 | 75.56 | 85.20 |
+| | Shrubland | 99.24 | 98.69 | 99.83 | 99.26 | 99.83 | 98.64 | 99.23 |
+| | Snow | 86.50 | 78.41 | 91.65 | 84.51 | 93.67 | 83.04 | 88.03 |
+| | Urban | 95.29 | 88.87 | 98.29 | 93.34 | 99.08 | 93.76 | 96.35 |
+| | Wetlands | 93.72 | 93.90 | 98.36 | 96.07 | 92.92 | 77.14 | 84.30 |
+| | Water | 97.52 | 96.65 | 99.29 | 97.95 | 98.90 | 94.92 | 96.87 |
+| | Mean | 94.35 | 91.75 | 97.75 | 94.65 | 97.47 | 90.79 | 94.01 |
+| GF1 WHU | Mean | 94.46 | 92.07 | 97.62 | 94.76 | 97.31 | 91.11 | 94.11 |
+
+
+
+## Citation
+
+```bibtex
+@article{ChaiISPRS2024,
+  author = {Dengfeng Chai and Jingfeng Huang and Minghui Wu and
+            Xiaoping Yang and Ruisheng Wang},
+  title = {Remote sensing image cloud detection using a shallow convolutional neural network},
+  journal = {ISPRS Journal of Photogrammetry and Remote Sensing},
+ year = {2024},
+}
+```
\ No newline at end of file
diff --git a/configs/model/scnn/scnn.yaml b/configs/model/scnn/scnn.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..4246f8fa792f88b9a20a02b8ae6b258419b13959
--- /dev/null
+++ b/configs/model/scnn/scnn.yaml
@@ -0,0 +1,17 @@
+_target_: src.models.base_module.BaseLitModule
+
+net:
+ _target_: src.models.components.scnn.SCNN
+ num_classes: 2
+
+num_classes: 2
+
+criterion:
+ _target_: torch.nn.CrossEntropyLoss
+
+optimizer:
+ _target_: torch.optim.RMSprop
+ _partial_: true
+ lr: 0.0001
+
+scheduler: null
\ No newline at end of file
diff --git a/configs/model/unet/README.md b/configs/model/unet/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..ee75a24e7068dd6cfbf0fcb9c5d1b76edfb834bf
--- /dev/null
+++ b/configs/model/unet/README.md
@@ -0,0 +1,92 @@
+# UNet
+
+> [U-Net: Convolutional Networks for Biomedical Image Segmentation](https://arxiv.org/abs/1505.04597)
+
+## Introduction
+
+
+
+Official Repo
+
+Code Snippet
+
+## Abstract
+
+
+
+There is large consent that successful training of deep networks requires many thousand annotated training samples. In this paper, we present a network and training strategy that relies on the strong use of data augmentation to use the available annotated samples more efficiently. The architecture consists of a contracting path to capture context and a symmetric expanding path that enables precise localization. We show that such a network can be trained end-to-end from very few images and outperforms the prior best method (a sliding-window convolutional network) on the ISBI challenge for segmentation of neuronal structures in electron microscopic stacks. Using the same network trained on transmitted light microscopy images (phase contrast and DIC) we won the ISBI cell tracking challenge 2015 in these categories by a large margin. Moreover, the network is fast. Segmentation of a 512x512 image takes less than a second on a recent GPU. The full implementation (based on Caffe) and the trained networks are available at [this http URL](https://lmb.informatik.uni-freiburg.de/people/ronneber/u-net/).
+
+
+
+
+
+
+
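+As a reminder of the U-shape described above, here is a single-level toy sketch of a contracting path, an expanding path and one skip connection; the real U-Net stacks several such levels with more convolutions per level, so this is only an illustration.
+
+```python
+import torch
+import torch.nn as nn
+
+
+class TinyUNet(nn.Module):
+    """One-level U-shape: downsample to capture context, upsample back, and
+    concatenate the skip connection so precise localization is preserved."""
+
+    def __init__(self, in_channels: int = 3, num_classes: int = 2, width: int = 16):
+        super().__init__()
+        self.down = nn.Sequential(
+            nn.Conv2d(in_channels, width, 3, padding=1), nn.ReLU(inplace=True)
+        )
+        self.pool = nn.MaxPool2d(2)
+        self.bottleneck = nn.Sequential(
+            nn.Conv2d(width, width * 2, 3, padding=1), nn.ReLU(inplace=True)
+        )
+        self.up = nn.ConvTranspose2d(width * 2, width, kernel_size=2, stride=2)
+        self.out = nn.Conv2d(width * 2, num_classes, kernel_size=1)
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        skip = self.down(x)                        # contracting path
+        mid = self.bottleneck(self.pool(skip))     # lowest resolution
+        up = self.up(mid)                          # expanding path
+        return self.out(torch.cat([up, skip], 1))  # skip connection
+
+
+if __name__ == "__main__":
+    print(TinyUNet()(torch.rand(1, 3, 64, 64)).shape)  # torch.Size([1, 2, 64, 64])
+```
+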
+## Results and models
+
+### Cityscapes
+
+| Method | Backbone | Loss | Crop Size | Lr schd | Mem (GB) | Inf time (fps) | Device | mIoU | mIoU(ms+flip) | config | download |
+| ---------- | ----------- | ------------- | --------- | ------: | -------- | -------------- | ------ | ----: | ------------: | ------------------------------------------------------------------------------------------------------------------------------ | ---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- |
+| UNet + FCN | UNet-S5-D16 | Cross Entropy | 512x1024 | 160000 | 17.91 | 3.05 | V100 | 69.10 | 71.05 | [config](https://github.com/open-mmlab/mmsegmentation/blob/main/configs/unet/unet-s5-d16_fcn_4xb4-160k_cityscapes-512x1024.py) | [model](https://download.openmmlab.com/mmsegmentation/v0.5/unet/fcn_unet_s5-d16_4x4_512x1024_160k_cityscapes/fcn_unet_s5-d16_4x4_512x1024_160k_cityscapes_20211210_145204-6860854e.pth) \| [log](https://download.openmmlab.com/mmsegmentation/v0.5/unet/fcn_unet_s5-d16_4x4_512x1024_160k_cityscapes/fcn_unet_s5-d16_4x4_512x1024_160k_cityscapes_20211210_145204.log.json) |
+
+### DRIVE
+
+| Method | Backbone | Loss | Image Size | Crop Size | Stride | Lr schd | Mem (GB) | Inf time (fps) | Device | mDice | Dice | config | download |
+| ---------------- | ----------- | -------------------- | ---------- | --------- | -----: | ------- | -------- | -------------: | ------ | ----: | ----: | ------------------------------------------------------------------------------------------------------------------------------------------- | ---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- |
+| UNet + FCN | UNet-S5-D16 | Cross Entropy | 584x565 | 64x64 | 42x42 | 40000 | 0.680 | - | V100 | 88.38 | 78.67 | [config](https://github.com/open-mmlab/mmsegmentation/blob/main/configs/unet/unet-s5-d16_fcn_4xb4-40k_drive-64x64.py) | [model](https://download.openmmlab.com/mmsegmentation/v0.5/unet/fcn_unet_s5-d16_64x64_40k_drive/fcn_unet_s5-d16_64x64_40k_drive_20201223_191051-5daf6d3b.pth) \| [log](https://download.openmmlab.com/mmsegmentation/v0.5/unet/unet_s5-d16_64x64_40k_drive/unet_s5-d16_64x64_40k_drive-20201223_191051.log.json) |
+| UNet + FCN | UNet-S5-D16 | Cross Entropy + Dice | 584x565 | 64x64 | 42x42 | 40000 | 0.582 | - | V100 | 88.71 | 79.32 | [config](https://github.com/open-mmlab/mmsegmentation/blob/main/configs/unet/unet-s5-d16_fcn_4xb4-ce-1.0-dice-3.0-40k_drive-64x64.py) | [model](https://download.openmmlab.com/mmsegmentation/v0.5/unet/fcn_unet_s5-d16_ce-1.0-dice-3.0_64x64_40k_drive/fcn_unet_s5-d16_ce-1.0-dice-3.0_64x64_40k_drive_20211210_201820-785de5c2.pth) \| [log](https://download.openmmlab.com/mmsegmentation/v0.5/unet/fcn_unet_s5-d16_ce-1.0-dice-3.0_64x64_40k_drive/fcn_unet_s5-d16_ce-1.0-dice-3.0_64x64_40k_drive_20211210_201820.log.json) |
+| UNet + PSPNet | UNet-S5-D16 | Cross Entropy | 584x565 | 64x64 | 42x42 | 40000 | 0.599 | - | V100 | 88.35 | 78.62 | [config](https://github.com/open-mmlab/mmsegmentation/blob/main/configs/unet/unet-s5-d16_pspnet_4xb4-40k_drive-64x64.py) | [model](https://download.openmmlab.com/mmsegmentation/v0.5/unet/pspnet_unet_s5-d16_64x64_40k_drive/pspnet_unet_s5-d16_64x64_40k_drive_20201227_181818-aac73387.pth) \| [log](https://download.openmmlab.com/mmsegmentation/v0.5/unet/pspnet_unet_s5-d16_64x64_40k_drive/pspnet_unet_s5-d16_64x64_40k_drive-20201227_181818.log.json) |
+| UNet + PSPNet | UNet-S5-D16 | Cross Entropy + Dice | 584x565 | 64x64 | 42x42 | 40000 | 0.585 | - | V100 | 88.76 | 79.42 | [config](https://github.com/open-mmlab/mmsegmentation/blob/main/configs/unet/unet-s5-d16_pspnet_4xb4-ce-1.0-dice-3.0-40k_drive-64x64.py) | [model](https://download.openmmlab.com/mmsegmentation/v0.5/unet/pspnet_unet_s5-d16_ce-1.0-dice-3.0_64x64_40k_drive/pspnet_unet_s5-d16_ce-1.0-dice-3.0_64x64_40k_drive_20211210_201821-22b3e3ba.pth) \| [log](https://download.openmmlab.com/mmsegmentation/v0.5/unet/pspnet_unet_s5-d16_ce-1.0-dice-3.0_64x64_40k_drive/pspnet_unet_s5-d16_ce-1.0-dice-3.0_64x64_40k_drive_20211210_201821.log.json) |
+| UNet + DeepLabV3 | UNet-S5-D16 | Cross Entropy | 584x565 | 64x64 | 42x42 | 40000 | 0.596 | - | V100 | 88.38 | 78.69 | [config](https://github.com/open-mmlab/mmsegmentation/blob/main/configs/unet/unet-s5-d16_deeplabv3_4xb4-40k_drive-64x64.py) | [model](https://download.openmmlab.com/mmsegmentation/v0.5/unet/deeplabv3_unet_s5-d16_64x64_40k_drive/deeplabv3_unet_s5-d16_64x64_40k_drive_20201226_094047-0671ff20.pth) \| [log](https://download.openmmlab.com/mmsegmentation/v0.5/unet/deeplabv3_unet_s5-d16_64x64_40k_drive/deeplabv3_unet_s5-d16_64x64_40k_drive-20201226_094047.log.json) |
+| UNet + DeepLabV3 | UNet-S5-D16 | Cross Entropy + Dice | 584x565 | 64x64 | 42x42 | 40000 | 0.582 | - | V100 | 88.84 | 79.56 | [config](https://github.com/open-mmlab/mmsegmentation/blob/main/configs/unet/unet-s5-d16_deeplabv3_4xb4-ce-1.0-dice-3.0-40k_drive-64x64.py) | [model](https://download.openmmlab.com/mmsegmentation/v0.5/unet/deeplabv3_unet_s5-d16_ce-1.0-dice-3.0_64x64_40k_drive/deeplabv3_unet_s5-d16_ce-1.0-dice-3.0_64x64_40k_drive_20211210_201825-6bf0efd7.pth) \| [log](https://download.openmmlab.com/mmsegmentation/v0.5/unet/deeplabv3_unet_s5-d16_ce-1.0-dice-3.0_64x64_40k_drive/deeplabv3_unet_s5-d16_ce-1.0-dice-3.0_64x64_40k_drive_20211210_201825.log.json) |
+
+### STARE
+
+| Method | Backbone | Loss | Image Size | Crop Size | Stride | Lr schd | Mem (GB) | Inf time (fps) | Device | mDice | Dice | config | download |
+| ---------------- | ----------- | -------------------- | ---------- | --------- | -----: | ------- | -------- | -------------: | ------ | ----: | ----: | --------------------------------------------------------------------------------------------------------------------------------------------- | ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ |
+| UNet + FCN | UNet-S5-D16 | Cross Entropy | 605x700 | 128x128 | 85x85 | 40000 | 0.968 | - | V100 | 89.78 | 81.02 | [config](https://github.com/open-mmlab/mmsegmentation/blob/main/configs/unet/unet-s5-d16_fcn_4xb4-40k_stare-128x128.py) | [model](https://download.openmmlab.com/mmsegmentation/v0.5/unet/fcn_unet_s5-d16_128x128_40k_stare/fcn_unet_s5-d16_128x128_40k_stare_20201223_191051-7d77e78b.pth) \| [log](https://download.openmmlab.com/mmsegmentation/v0.5/unet/unet_s5-d16_128x128_40k_stare/unet_s5-d16_128x128_40k_stare-20201223_191051.log.json) |
+| UNet + FCN | UNet-S5-D16 | Cross Entropy + Dice | 605x700 | 128x128 | 85x85 | 40000 | 0.986 | - | V100 | 90.65 | 82.70 | [config](https://github.com/open-mmlab/mmsegmentation/blob/main/configs/unet/unet-s5-d16_fcn_4xb4-ce-1.0-dice-3.0-40k_stare-128x128.py) | [model](https://download.openmmlab.com/mmsegmentation/v0.5/unet/fcn_unet_s5-d16_ce-1.0-dice-3.0_128x128_40k_stare/fcn_unet_s5-d16_ce-1.0-dice-3.0_128x128_40k_stare_20211210_201821-f75705a9.pth) \| [log](https://download.openmmlab.com/mmsegmentation/v0.5/unet/fcn_unet_s5-d16_ce-1.0-dice-3.0_128x128_40k_stare/fcn_unet_s5-d16_ce-1.0-dice-3.0_128x128_40k_stare_20211210_201821.log.json) |
+| UNet + PSPNet | UNet-S5-D16 | Cross Entropy | 605x700 | 128x128 | 85x85 | 40000 | 0.982 | - | V100 | 89.89 | 81.22 | [config](https://github.com/open-mmlab/mmsegmentation/blob/main/configs/unet/unet-s5-d16_pspnet_4xb4-40k_stare-128x128.py) | [model](https://download.openmmlab.com/mmsegmentation/v0.5/unet/pspnet_unet_s5-d16_128x128_40k_stare/pspnet_unet_s5-d16_128x128_40k_stare_20201227_181818-3c2923c4.pth) \| [log](https://download.openmmlab.com/mmsegmentation/v0.5/unet/pspnet_unet_s5-d16_128x128_40k_stare/pspnet_unet_s5-d16_128x128_40k_stare-20201227_181818.log.json) |
+| UNet + PSPNet | UNet-S5-D16 | Cross Entropy + Dice | 605x700 | 128x128 | 85x85 | 40000 | 1.028 | - | V100 | 90.72 | 82.84 | [config](https://github.com/open-mmlab/mmsegmentation/blob/main/configs/unet/unet-s5-d16_pspnet_4xb4-ce-1.0-dice-3.0-40k_stare-128x128.py) | [model](https://download.openmmlab.com/mmsegmentation/v0.5/unet/pspnet_unet_s5-d16_ce-1.0-dice-3.0_128x128_40k_stare/pspnet_unet_s5-d16_ce-1.0-dice-3.0_128x128_40k_stare_20211210_201823-f1063ef7.pth) \| [log](https://download.openmmlab.com/mmsegmentation/v0.5/unet/pspnet_unet_s5-d16_ce-1.0-dice-3.0_128x128_40k_stare/pspnet_unet_s5-d16_ce-1.0-dice-3.0_128x128_40k_stare_20211210_201823.log.json) |
+| UNet + DeepLabV3 | UNet-S5-D16 | Cross Entropy | 605x700 | 128x128 | 85x85 | 40000 | 0.999 | - | V100 | 89.73 | 80.93 | [config](https://github.com/open-mmlab/mmsegmentation/blob/main/configs/unet/unet-s5-d16_deeplabv3_4xb4-40k_stare-128x128.py) | [model](https://download.openmmlab.com/mmsegmentation/v0.5/unet/deeplabv3_unet_s5-d16_128x128_40k_stare/deeplabv3_unet_s5-d16_128x128_40k_stare_20201226_094047-93dcb93c.pth) \| [log](https://download.openmmlab.com/mmsegmentation/v0.5/unet/deeplabv3_unet_s5-d16_128x128_40k_stare/deeplabv3_unet_s5-d16_128x128_40k_stare-20201226_094047.log.json) |
+| UNet + DeepLabV3 | UNet-S5-D16 | Cross Entropy + Dice | 605x700 | 128x128 | 85x85 | 40000 | 1.010 | - | V100 | 90.65 | 82.71 | [config](https://github.com/open-mmlab/mmsegmentation/blob/main/configs/unet/unet-s5-d16_deeplabv3_4xb4-ce-1.0-dice-3.0-40k_stare-128x128.py) | [model](https://download.openmmlab.com/mmsegmentation/v0.5/unet/deeplabv3_unet_s5-d16_ce-1.0-dice-3.0_128x128_40k_stare/deeplabv3_unet_s5-d16_ce-1.0-dice-3.0_128x128_40k_stare_20211210_201825-21db614c.pth) \| [log](https://download.openmmlab.com/mmsegmentation/v0.5/unet/deeplabv3_unet_s5-d16_ce-1.0-dice-3.0_128x128_40k_stare/deeplabv3_unet_s5-d16_ce-1.0-dice-3.0_128x128_40k_stare_20211210_201825.log.json) |
+
+### CHASE_DB1
+
+| Method | Backbone | Loss | Image Size | Crop Size | Stride | Lr schd | Mem (GB) | Inf time (fps) | Device | mDice | Dice | config | download |
+| ---------------- | ----------- | -------------------- | ---------- | --------- | -----: | ------- | -------- | -------------: | ------ | ----: | ----: | ------------------------------------------------------------------------------------------------------------------------------------------------- | ---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- |
+| UNet + FCN | UNet-S5-D16 | Cross Entropy | 960x999 | 128x128 | 85x85 | 40000 | 0.968 | - | V100 | 89.46 | 80.24 | [config](https://github.com/open-mmlab/mmsegmentation/blob/main/configs/unet/unet-s5-d16_fcn_4xb4-40k_chase-db1-128x128.py) | [model](https://download.openmmlab.com/mmsegmentation/v0.5/unet/fcn_unet_s5-d16_128x128_40k_chase_db1/fcn_unet_s5-d16_128x128_40k_chase_db1_20201223_191051-11543527.pth) \| [log](https://download.openmmlab.com/mmsegmentation/v0.5/unet/unet_s5-d16_128x128_40k_chase_db1/unet_s5-d16_128x128_40k_chase_db1-20201223_191051.log.json) |
+| UNet + FCN | UNet-S5-D16 | Cross Entropy + Dice | 960x999 | 128x128 | 85x85 | 40000 | 0.986 | - | V100 | 89.52 | 80.40 | [config](https://github.com/open-mmlab/mmsegmentation/blob/main/configs/unet/unet-s5-d16_fcn_4xb4-ce-1.0-dice-3.0-40k_chase-db1-128x128.py) | [model](https://download.openmmlab.com/mmsegmentation/v0.5/unet/fcn_unet_s5-d16_ce-1.0-dice-3.0_128x128_40k_chase-db1/fcn_unet_s5-d16_ce-1.0-dice-3.0_128x128_40k_chase-db1_20211210_201821-1c4eb7cf.pth) \| [log](https://download.openmmlab.com/mmsegmentation/v0.5/unet/fcn_unet_s5-d16_ce-1.0-dice-3.0_128x128_40k_chase-db1/fcn_unet_s5-d16_ce-1.0-dice-3.0_128x128_40k_chase-db1_20211210_201821.log.json) |
+| UNet + PSPNet | UNet-S5-D16 | Cross Entropy | 960x999 | 128x128 | 85x85 | 40000 | 0.982 | - | V100 | 89.52 | 80.36 | [config](https://github.com/open-mmlab/mmsegmentation/blob/main/configs/unet/unet-s5-d16_pspnet_4xb4-40k_chase-db1-128x128.py) | [model](https://download.openmmlab.com/mmsegmentation/v0.5/unet/pspnet_unet_s5-d16_128x128_40k_chase_db1/pspnet_unet_s5-d16_128x128_40k_chase_db1_20201227_181818-68d4e609.pth) \| [log](https://download.openmmlab.com/mmsegmentation/v0.5/unet/pspnet_unet_s5-d16_128x128_40k_chase_db1/pspnet_unet_s5-d16_128x128_40k_chase_db1-20201227_181818.log.json) |
+| UNet + PSPNet | UNet-S5-D16 | Cross Entropy + Dice | 960x999 | 128x128 | 85x85 | 40000 | 1.028 | - | V100 | 89.45 | 80.28 | [config](https://github.com/open-mmlab/mmsegmentation/blob/main/configs/unet/unet-s5-d16_pspnet_4xb4-ce-1.0-dice-3.0-40k_chase-db1-128x128.py) | [model](https://download.openmmlab.com/mmsegmentation/v0.5/unet/pspnet_unet_s5-d16_ce-1.0-dice-3.0_128x128_40k_chase-db1/pspnet_unet_s5-d16_ce-1.0-dice-3.0_128x128_40k_chase-db1_20211210_201823-c0802c4d.pth) \| [log](https://download.openmmlab.com/mmsegmentation/v0.5/unet/pspnet_unet_s5-d16_ce-1.0-dice-3.0_128x128_40k_chase-db1/pspnet_unet_s5-d16_ce-1.0-dice-3.0_128x128_40k_chase-db1_20211210_201823.log.json) |
+| UNet + DeepLabV3 | UNet-S5-D16 | Cross Entropy | 960x999 | 128x128 | 85x85 | 40000 | 0.999 | - | V100 | 89.57 | 80.47 | [config](https://github.com/open-mmlab/mmsegmentation/blob/main/configs/unet/unet_s5-d16_deeplabv3_4xb4-40k_chase-db1-128x128.py) | [model](https://download.openmmlab.com/mmsegmentation/v0.5/unet/deeplabv3_unet_s5-d16_128x128_40k_chase_db1/deeplabv3_unet_s5-d16_128x128_40k_chase_db1_20201226_094047-4c5aefa3.pth) \| [log](https://download.openmmlab.com/mmsegmentation/v0.5/unet/deeplabv3_unet_s5-d16_128x128_40k_chase_db1/deeplabv3_unet_s5-d16_128x128_40k_chase_db1-20201226_094047.log.json) |
+| UNet + DeepLabV3 | UNet-S5-D16 | Cross Entropy + Dice | 960x999 | 128x128 | 85x85 | 40000 | 1.010 | - | V100 | 89.49 | 80.37 | [config](https://github.com/open-mmlab/mmsegmentation/blob/main/configs/unet/unet-s5-d16_deeplabv3_4xb4-ce-1.0-dice-3.0-40k_chase-db1-128x128.py) | [model](https://download.openmmlab.com/mmsegmentation/v0.5/unet/deeplabv3_unet_s5-d16_ce-1.0-dice-3.0_128x128_40k_chase-db1/deeplabv3_unet_s5-d16_ce-1.0-dice-3.0_128x128_40k_chase-db1_20211210_201825-4ef29df5.pth) \| [log](https://download.openmmlab.com/mmsegmentation/v0.5/unet/deeplabv3_unet_s5-d16_ce-1.0-dice-3.0_128x128_40k_chase-db1/deeplabv3_unet_s5-d16_ce-1.0-dice-3.0_128x128_40k_chase-db1_20211210_201825.log.json) |
+
+### HRF
+
+| Method | Backbone | Loss | Image Size | Crop Size | Stride | Lr schd | Mem (GB) | Inf time (fps) | Device | mDice | Dice | config | download |
+| ---------------- | ----------- | -------------------- | ---------- | --------- | ------: | ------- | -------- | -------------: | ------ | ----: | ----: | ------------------------------------------------------------------------------------------------------------------------------------------- | ---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- |
+| UNet + FCN | UNet-S5-D16 | Cross Entropy | 2336x3504 | 256x256 | 170x170 | 40000 | 2.525 | - | V100 | 88.92 | 79.45 | [config](https://github.com/open-mmlab/mmsegmentation/blob/main/configs/unet/unet-s5-d16_fcn_4xb4-40k_hrf-256x256.py) | [model](https://download.openmmlab.com/mmsegmentation/v0.5/unet/fcn_unet_s5-d16_256x256_40k_hrf/fcn_unet_s5-d16_256x256_40k_hrf_20201223_173724-d89cf1ed.pth) \| [log](https://download.openmmlab.com/mmsegmentation/v0.5/unet/unet_s5-d16_256x256_40k_hrf/unet_s5-d16_256x256_40k_hrf-20201223_173724.log.json) |
+| UNet + FCN | UNet-S5-D16 | Cross Entropy + Dice | 2336x3504 | 256x256 | 170x170 | 40000 | 2.623 | - | V100 | 89.64 | 80.87 | [config](https://github.com/open-mmlab/mmsegmentation/blob/main/configs/unet/unet-s5-d16_fcn_4xb4-ce-1.0-dice-3.0-40k_hrf-256x256.py) | [model](https://download.openmmlab.com/mmsegmentation/v0.5/unet/fcn_unet_s5-d16_ce-1.0-dice-3.0_256x256_40k_hrf/fcn_unet_s5-d16_ce-1.0-dice-3.0_256x256_40k_hrf_20211210_201821-c314da8a.pth) \| [log](https://download.openmmlab.com/mmsegmentation/v0.5/unet/fcn_unet_s5-d16_ce-1.0-dice-3.0_256x256_40k_hrf/fcn_unet_s5-d16_ce-1.0-dice-3.0_256x256_40k_hrf_20211210_201821.log.json) |
+| UNet + PSPNet | UNet-S5-D16 | Cross Entropy | 2336x3504 | 256x256 | 170x170 | 40000 | 2.588 | - | V100 | 89.24 | 80.07 | [config](https://github.com/open-mmlab/mmsegmentation/blob/main/configs/unet/unet-s5-d16_pspnet_4xb4-40k_hrf-256x256.py) | [model](https://download.openmmlab.com/mmsegmentation/v0.5/unet/pspnet_unet_s5-d16_256x256_40k_hrf/pspnet_unet_s5-d16_256x256_40k_hrf_20201227_181818-fdb7e29b.pth) \| [log](https://download.openmmlab.com/mmsegmentation/v0.5/unet/pspnet_unet_s5-d16_256x256_40k_hrf/pspnet_unet_s5-d16_256x256_40k_hrf-20201227_181818.log.json) |
+| UNet + PSPNet | UNet-S5-D16 | Cross Entropy + Dice | 2336x3504 | 256x256 | 170x170 | 40000 | 2.798 | - | V100 | 89.69 | 80.96 | [config](https://github.com/open-mmlab/mmsegmentation/blob/main/configs/unet/unet-s5-d16_pspnet_4xb4-ce-1.0-dice-3.0-40k_hrf-256x256.py) | [model](https://download.openmmlab.com/mmsegmentation/v0.5/unet/pspnet_unet_s5-d16_ce-1.0-dice-3.0_256x256_40k_hrf/pspnet_unet_s5-d16_ce-1.0-dice-3.0_256x256_40k_hrf_20211210_201823-53d492fa.pth) \| [log](https://download.openmmlab.com/mmsegmentation/v0.5/unet/pspnet_unet_s5-d16_ce-1.0-dice-3.0_256x256_40k_hrf/pspnet_unet_s5-d16_ce-1.0-dice-3.0_256x256_40k_hrf_20211210_201823.log.json) |
+| UNet + DeepLabV3 | UNet-S5-D16 | Cross Entropy | 2336x3504 | 256x256 | 170x170 | 40000 | 2.604 | - | V100 | 89.32 | 80.21 | [config](https://github.com/open-mmlab/mmsegmentation/blob/main/configs/unet/unet-s5-d16_deeplabv3_4xb4-40k_hrf-256x256.py) | [model](https://download.openmmlab.com/mmsegmentation/v0.5/unet/deeplabv3_unet_s5-d16_256x256_40k_hrf/deeplabv3_unet_s5-d16_256x256_40k_hrf_20201226_094047-3a1fdf85.pth) \| [log](https://download.openmmlab.com/mmsegmentation/v0.5/unet/deeplabv3_unet_s5-d16_256x256_40k_hrf/deeplabv3_unet_s5-d16_256x256_40k_hrf-20201226_094047.log.json) |
+| UNet + DeepLabV3 | UNet-S5-D16 | Cross Entropy + Dice | 2336x3504 | 256x256 | 170x170 | 40000 | 2.607 | - | V100 | 89.56 | 80.71 | [config](https://github.com/open-mmlab/mmsegmentation/blob/main/configs/unet/unet-s5-d16_deeplabv3_4xb4-ce-1.0-dice-3.0-40k_hrf-256x256.py) | [model](https://download.openmmlab.com/mmsegmentation/v0.5/unet/deeplabv3_unet_s5-d16_ce-1.0-dice-3.0_256x256_40k_hrf/deeplabv3_unet_s5-d16_ce-1.0-dice-3.0_256x256_40k_hrf_20211210_202032-59daf7a4.pth) \| [log](https://download.openmmlab.com/mmsegmentation/v0.5/unet/deeplabv3_unet_s5-d16_ce-1.0-dice-3.0_256x256_40k_hrf/deeplabv3_unet_s5-d16_ce-1.0-dice-3.0_256x256_40k_hrf_20211210_202032.log.json) |
+
+Note:
+
+- In the `DRIVE`, `STARE`, `CHASE_DB1`, and `HRF` datasets, `mDice` is the mean Dice over the background and vessel classes, while `Dice` is the Dice score of the vessel (foreground) class only (see the short sketch below).
+
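+A minimal illustration (not part of the benchmark code) of how the two numbers relate, assuming binary masks where `1` marks vessel pixels:
+
+```python
+import numpy as np
+
+
+def dice(pred: np.ndarray, gt: np.ndarray) -> float:
+    """Dice = 2 * |pred & gt| / (|pred| + |gt|)."""
+    inter = np.logical_and(pred, gt).sum()
+    denom = pred.sum() + gt.sum()
+    return 2.0 * inter / denom if denom else 1.0
+
+
+pred = np.array([[0, 1], [1, 1]])  # predicted vessel mask
+gt = np.array([[0, 1], [0, 1]])    # ground-truth vessel mask
+
+fg_dice = dice(pred == 1, gt == 1)  # "Dice": vessel (foreground) class only
+bg_dice = dice(pred == 0, gt == 0)  # Dice of the background class
+m_dice = (fg_dice + bg_dice) / 2    # "mDice": mean over background and vessel
+```
+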
+## Citation
+
+```bibtex
+@inproceedings{ronneberger2015u,
+ title={U-net: Convolutional networks for biomedical image segmentation},
+ author={Ronneberger, Olaf and Fischer, Philipp and Brox, Thomas},
+ booktitle={International Conference on Medical image computing and computer-assisted intervention},
+ pages={234--241},
+ year={2015},
+ organization={Springer}
+}
+```
\ No newline at end of file
diff --git a/configs/model/unet/unet.yaml b/configs/model/unet/unet.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..91e70a9bdfad1102b183d69b0fb4b7b2e9960d3b
--- /dev/null
+++ b/configs/model/unet/unet.yaml
@@ -0,0 +1,57 @@
+# @package _global_
+
+# to execute this experiment run:
+# python train.py experiment=example
+
+defaults:
+ - override /trainer: gpu
+ - override /data: hrcWhu
+ - override /model: unet
+ - override /logger: wandb
+ - override /callbacks: default
+
+# all parameters below will be merged with parameters from default configurations set above
+# this allows you to overwrite only specified parameters
+
+tags: ["hrcWhu", "unet"]
+
+seed: 42
+
+trainer:
+ min_epochs: 10
+ max_epochs: 10
+ gradient_clip_val: 0.5
+ devices: 1
+
+data:
+ batch_size: 128
+ train_val_test_split: [55_000, 5_000, 10_000]
+ num_workers: 31
+ pin_memory: False
+ persistent_workers: False
+
+model:
+ in_channels: 3
+ out_channels: 7
+
+
+logger:
+ wandb:
+ project: "hrcWhu"
+ name: "unet"
+ aim:
+ experiment: "unet"
+
+callbacks:
+ model_checkpoint:
+ dirpath: ${paths.output_dir}/checkpoints
+ filename: "epoch_{epoch:03d}"
+ monitor: "val/loss"
+ mode: "min"
+ save_last: True
+ auto_insert_metric_name: False
+
+ early_stopping:
+ monitor: "val/loss"
+ patience: 100
+ mode: "min"
\ No newline at end of file
diff --git a/configs/paths/default.yaml b/configs/paths/default.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..ec81db2d34712909a79be3e42e65efe08c35ecee
--- /dev/null
+++ b/configs/paths/default.yaml
@@ -0,0 +1,18 @@
+# path to root directory
+# this requires PROJECT_ROOT environment variable to exist
+# you can replace it with "." if you want the root to be the current working directory
+root_dir: ${oc.env:PROJECT_ROOT}
+
+# path to data directory
+data_dir: ${paths.root_dir}/data/
+
+# path to logging directory
+log_dir: ${paths.root_dir}/logs/
+
+# path to output directory, created dynamically by hydra
+# path generation pattern is specified in `configs/hydra/default.yaml`
+# use it to store all files generated during the run, like ckpts and metrics
+output_dir: ${hydra:runtime.output_dir}
+
+# path to working directory
+work_dir: ${hydra:runtime.cwd}
diff --git a/configs/train.yaml b/configs/train.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..8e8d3e515725d75347d8b2f6aa4e0b80fa68fd82
--- /dev/null
+++ b/configs/train.yaml
@@ -0,0 +1,49 @@
+# @package _global_
+
+# specify here default configuration
+# order of defaults determines the order in which configs override each other
+defaults:
+ - _self_
+ - data: null
+ - model: null
+ - callbacks: default
+ - logger: null # set logger here or use command line (e.g. `python train.py logger=tensorboard`)
+ - trainer: default
+ - paths: default
+ - extras: default
+ - hydra: default
+
+ # experiment configs allow for version control of specific hyperparameters
+ # e.g. best hyperparameters for given model and datamodule
+ - experiment: null
+
+ # config for hyperparameter optimization
+ - hparams_search: null
+
+ # optional local config for machine/user specific settings
+ # it's optional since it doesn't need to exist and is excluded from version control
+ - optional local: default
+
+  # debugging config (enable through command line, e.g. `python train.py debug=default`)
+ - debug: null
+
+# task name, determines output directory path
+task_name: "train"
+
+# tags to help you identify your experiments
+# you can overwrite this in experiment configs
+# overwrite from command line with `python train.py tags="[first_tag, second_tag]"`
+tags: ["dev"]
+
+# set False to skip model training
+train: True
+
+# evaluate on test set, using best model weights achieved during training
+# lightning chooses best weights based on the metric specified in checkpoint callback
+test: True
+
+# simply provide checkpoint path to resume training
+ckpt_path: null
+
+# seed for random number generators in pytorch, numpy and python.random
+seed: null
diff --git a/configs/trainer/cpu.yaml b/configs/trainer/cpu.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..b7d6767e60c956567555980654f15e7bb673a41f
--- /dev/null
+++ b/configs/trainer/cpu.yaml
@@ -0,0 +1,5 @@
+defaults:
+ - default
+
+accelerator: cpu
+devices: 1
diff --git a/configs/trainer/ddp.yaml b/configs/trainer/ddp.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..ab8f89004c399a33440f014fa27e040d4e952bc2
--- /dev/null
+++ b/configs/trainer/ddp.yaml
@@ -0,0 +1,9 @@
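+# DistributedDataParallel training across 4 GPUs on a single node;
+# select it from the command line, e.g.: python src/train.py trainer=ddp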
+defaults:
+ - default
+
+strategy: ddp
+
+accelerator: gpu
+devices: 4
+num_nodes: 1
+sync_batchnorm: True
diff --git a/configs/trainer/ddp_sim.yaml b/configs/trainer/ddp_sim.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..8404419e5c295654967d0dfb73a7366e75be2f1f
--- /dev/null
+++ b/configs/trainer/ddp_sim.yaml
@@ -0,0 +1,7 @@
+defaults:
+ - default
+
+# simulate DDP on CPU, useful for debugging
+accelerator: cpu
+devices: 2
+strategy: ddp_spawn
diff --git a/configs/trainer/default.yaml b/configs/trainer/default.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..50905e7fdf158999e7c726edfff1a4dc16d548da
--- /dev/null
+++ b/configs/trainer/default.yaml
@@ -0,0 +1,19 @@
+_target_: lightning.pytorch.trainer.Trainer
+
+default_root_dir: ${paths.output_dir}
+
+min_epochs: 1 # prevents early stopping
+max_epochs: 10
+
+accelerator: cpu
+devices: 1
+
+# mixed precision for extra speed-up
+# precision: 16
+
+# perform a validation loop every N training epochs
+check_val_every_n_epoch: 1
+
+# set True to ensure deterministic results
+# makes training slower but gives more reproducibility than just setting seeds
+deterministic: False
diff --git a/configs/trainer/gpu.yaml b/configs/trainer/gpu.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..54d2c957bd106cf39352440c185ff4976b16e899
--- /dev/null
+++ b/configs/trainer/gpu.yaml
@@ -0,0 +1,7 @@
+defaults:
+ - default
+
+accelerator: gpu
+devices: 1
+min_epochs: 10
+max_epochs: 10000
diff --git a/configs/trainer/mps.yaml b/configs/trainer/mps.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..1ecf6d5cc3a34ca127c5510f4a18e989561e38e4
--- /dev/null
+++ b/configs/trainer/mps.yaml
@@ -0,0 +1,5 @@
+defaults:
+ - default
+
+accelerator: mps
+devices: 1
diff --git a/environment.yaml b/environment.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..76cbd9be60c9cb15921bd94d0dae0c8573c4c507
--- /dev/null
+++ b/environment.yaml
@@ -0,0 +1,27 @@
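+# conda environment specification; create it with: conda env create -f environment.yaml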
+name: cloudseg
+
+channels:
+ - pytorch
+ - conda-forge
+ - defaults
+
+dependencies:
+ - python=3.8.0
+ - pytorch=2.0.0
+ - torchvision=0.15.0
+ - lightning==2.0.0
+ - pip=24.2
+
+ - pip:
+ - torchmetrics
+ - hydra-core
+ - hydra-optuna-sweeper
+ - hydra-colorlog
+ - rich
+ - pytest
+ - rootutils
+ - wandb
+ - aim
+ - gradio
+ - image-dehazer
+      - thop
+      - albumentations
+      - einops
+      - h5py
+      - matplotlib
\ No newline at end of file
diff --git a/notebooks/.gitkeep b/notebooks/.gitkeep
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/pyproject.toml b/pyproject.toml
new file mode 100644
index 0000000000000000000000000000000000000000..300ebf04f0594f1c517e7017d3697490009b203f
--- /dev/null
+++ b/pyproject.toml
@@ -0,0 +1,25 @@
+[tool.pytest.ini_options]
+addopts = [
+ "--color=yes",
+ "--durations=0",
+ "--strict-markers",
+ "--doctest-modules",
+]
+filterwarnings = [
+ "ignore::DeprecationWarning",
+ "ignore::UserWarning",
+]
+log_cli = "True"
+markers = [
+ "slow: slow tests",
+]
+minversion = "6.0"
+testpaths = "tests/"
+
+[tool.coverage.report]
+exclude_lines = [
+ "pragma: nocover",
+ "raise NotImplementedError",
+ "raise NotImplementedError()",
+ "if __name__ == .__main__.:",
+]
diff --git a/scripts/schedule.sh b/scripts/schedule.sh
new file mode 100644
index 0000000000000000000000000000000000000000..44b3da1116ef4d54e9acffee7d639d549e136d45
--- /dev/null
+++ b/scripts/schedule.sh
@@ -0,0 +1,7 @@
+#!/bin/bash
+# Schedule execution of many runs
+# Run from root folder with: bash scripts/schedule.sh
+
+python src/train.py trainer.max_epochs=5 logger=csv
+
+python src/train.py trainer.max_epochs=10 logger=csv
diff --git a/setup.py b/setup.py
new file mode 100644
index 0000000000000000000000000000000000000000..40ad3e19f3be584f5b79248f5185634deb15dc4b
--- /dev/null
+++ b/setup.py
@@ -0,0 +1,21 @@
+#!/usr/bin/env python
+
+from setuptools import find_packages, setup
+
+setup(
+ name="src",
+ version="0.0.1",
+ description="Describe Your Cool Project",
+ author="",
+ author_email="",
+ url="https://github.com/XavierJiezou/cloudseg",
+ install_requires=["lightning", "hydra-core"],
+ packages=find_packages(),
+ # use this to customize global commands available in the terminal after installing the package
+ entry_points={
+ "console_scripts": [
+ "train_command = src.train:main",
+ "eval_command = src.eval:main",
+ ]
+ },
+)
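+
+# Example usage (assumption): after an editable install with `pip install -e .`,
+# the console scripts defined above are available on the PATH as `train_command` and `eval_command`.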
diff --git a/src/data/__init__.py b/src/data/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/src/data/components/__init__.py b/src/data/components/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/src/data/components/celeba.py b/src/data/components/celeba.py
new file mode 100644
index 0000000000000000000000000000000000000000..826a1a45ecd6eceaa0e8ce5ec9336d564eb2e85b
--- /dev/null
+++ b/src/data/components/celeba.py
@@ -0,0 +1,234 @@
+import json
+import os
+import random
+
+import albumentations
+import matplotlib.pyplot as plt
+import numpy as np
+import torch
+from PIL import Image
+from torch.utils.data import Dataset
+
+
+class DalleTransformerPreprocessor(object):
+ def __init__(self,
+ size=256,
+ phase='train',
+ additional_targets=None):
+
+ self.size = size
+ self.phase = phase
+ # ddc: following dalle to use randomcrop
+ self.train_preprocessor = albumentations.Compose([albumentations.RandomCrop(height=size, width=size)],
+ additional_targets=additional_targets)
+ self.val_preprocessor = albumentations.Compose([albumentations.CenterCrop(height=size, width=size)],
+ additional_targets=additional_targets)
+
+
+    def __call__(self, image, **kwargs):
+ """
+ image: PIL.Image
+ """
+ if isinstance(image, np.ndarray):
+ image = Image.fromarray(image.astype(np.uint8))
+
+ w, h = image.size
+ s_min = min(h, w)
+
+ if self.phase == 'train':
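+            # jitter a square crop of side s_min around the centre (offset between 3/8 and
+            # 5/8 of the excess in each dimension), then randomly resize the square to a side
+            # in [self.size, min(s_min, 9/8 * self.size)] before the final RandomCrop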
+ off_h = int(random.uniform(3*(h-s_min)//8, max(3*(h-s_min)//8+1, 5*(h-s_min)//8)))
+ off_w = int(random.uniform(3*(w-s_min)//8, max(3*(w-s_min)//8+1, 5*(w-s_min)//8)))
+
+ image = image.crop((off_w, off_h, off_w + s_min, off_h + s_min))
+
+ # resize image
+ t_max = min(s_min, round(9/8*self.size))
+ t_max = max(t_max, self.size)
+ t = int(random.uniform(self.size, t_max+1))
+ image = image.resize((t, t))
+ image = np.array(image).astype(np.uint8)
+ image = self.train_preprocessor(image=image)
+ else:
+ if w < h:
+ w_ = self.size
+ h_ = int(h * w_/w)
+ else:
+ h_ = self.size
+ w_ = int(w * h_/h)
+ image = image.resize((w_, h_))
+ image = np.array(image).astype(np.uint8)
+ image = self.val_preprocessor(image=image)
+ return image
+
+
+class CelebA(Dataset):
+
+ """
+ This Dataset can be used for:
+ - image-only: setting 'conditions' = []
+ - image and multi-modal 'conditions': setting conditions as the list of modalities you need
+
+ To toggle between 256 and 512 image resolution, simply change the 'image_folder'
+ """
+
+ def __init__(
+ self,
+ phase='train',
+ size=512,
+ test_dataset_size=3000,
+ conditions=['seg_mask', 'text', 'sketch'],
+ image_folder='data/celeba/image/image_512_downsampled_from_hq_1024',
+ text_file='data/celeba/text/captions_hq_beard_and_age_2022-08-19.json',
+ mask_folder='data/celeba/mask/CelebAMask-HQ-mask-color-palette_32_nearest_downsampled_from_hq_512_one_hot_2d_tensor',
+ sketch_folder='data/celeba/sketch/sketch_1x1024_tensor',
+ ):
+ self.transform = DalleTransformerPreprocessor(size=size, phase=phase)
+ self.conditions = conditions
+
+ self.image_folder = image_folder
+
+ # conditions directory
+ self.text_file = text_file
+ with open(self.text_file, 'r') as f:
+ self.text_file_content = json.load(f)
+ if 'seg_mask' in self.conditions:
+ self.mask_folder = mask_folder
+ if 'sketch' in self.conditions:
+ self.sketch_folder = sketch_folder
+
+ # list of valid image names & train test split
+ self.image_name_list = list(self.text_file_content.keys())
+
+ # train test split
+ if phase == 'train':
+ self.image_name_list = self.image_name_list[:-test_dataset_size]
+ elif phase == 'test':
+ self.image_name_list = self.image_name_list[-test_dataset_size:]
+ else:
+ raise NotImplementedError
+ self.num = len(self.image_name_list)
+
+ def __len__(self):
+ return self.num
+
+ def __getitem__(self, index):
+
+ # ---------- (1) get image ----------
+ image_name = self.image_name_list[index]
+ image_path = os.path.join(self.image_folder, image_name)
+ image = Image.open(image_path).convert('RGB')
+ image = np.array(image).astype(np.uint8)
+ image = self.transform(image=image)['image']
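+        # scale uint8 pixels from [0, 255] to float32 in [-1, 1]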
+ image = image.astype(np.float32)/127.5 - 1.0
+
+ # record into data entry
+ if len(self.conditions) == 1:
+ data = {
+ 'image': image,
+ }
+ else:
+ data = {
+ 'image': image,
+ 'conditions': {}
+ }
+
+ # ---------- (2) get text ----------
+ if 'text' in self.conditions:
+ text = self.text_file_content[image_name]["Beard_and_Age"].lower()
+ # record into data entry
+ if len(self.conditions) == 1:
+ data['caption'] = text
+ else:
+ data['conditions']['text'] = text
+
+ # ---------- (3) get mask ----------
+ if 'seg_mask' in self.conditions:
+ mask_idx = image_name.split('.')[0]
+ mask_name = f'{mask_idx}.pt'
+ mask_path = os.path.join(self.mask_folder, mask_name)
+ mask_one_hot_tensor = torch.load(mask_path)
+
+ # record into data entry
+ if len(self.conditions) == 1:
+ data['seg_mask'] = mask_one_hot_tensor
+ else:
+ data['conditions']['seg_mask'] = mask_one_hot_tensor
+
+ # ---------- (4) get sketch ----------
+ if 'sketch' in self.conditions:
+ sketch_idx = image_name.split('.')[0]
+ sketch_name = f'{sketch_idx}.pt'
+ sketch_path = os.path.join(self.sketch_folder, sketch_name)
+ sketch_one_hot_tensor = torch.load(sketch_path)
+
+ # record into data entry
+ if len(self.conditions) == 1:
+ data['sketch'] = sketch_one_hot_tensor
+ else:
+ data['conditions']['sketch'] = sketch_one_hot_tensor
+ data["image_name"] = image_name.split('.')[0]
+ return data
+
+
+if __name__ == '__main__':
+ # The caption file only has 29999 captions: https://github.com/ziqihuangg/CelebA-Dialog/issues/1
+
+ # Testing for `phase`
+ train_dataset = CelebA(phase="train")
+ test_dataset = CelebA(phase="test")
+ assert len(train_dataset)==26999
+ assert len(test_dataset)==3000
+
+ # Testing for `size`
+ size_512 = CelebA(size=512)
+ assert size_512[0]['image'].shape == (512, 512, 3)
+ assert size_512[0]["conditions"]['seg_mask'].shape == (19, 1024)
+ assert size_512[0]["conditions"]['sketch'].shape == (1, 1024)
+    size_256 = CelebA(size=256)
+    assert size_256[0]['image'].shape == (256, 256, 3)
+    assert size_256[0]["conditions"]['seg_mask'].shape == (19, 1024)
+    assert size_256[0]["conditions"]['sketch'].shape == (1, 1024)
+
+ # Testing for `conditions`
+ dataset = CelebA(conditions = ['seg_mask', 'text', 'sketch'])
+ image = dataset[0]["image"]
+ seg_mask= dataset[0]["conditions"]['seg_mask']
+ sketch = dataset[0]["conditions"]['sketch']
+ text = dataset[0]["conditions"]['text']
+    # show image, seg_mask, and sketch in a 1x3 grid, with the text as the figure title
+ fig, ax = plt.subplots(1, 3, figsize=(12, 4))
+
+ # Show image
+ ax[0].imshow((image + 1) / 2)
+ ax[0].set_title('Image')
+ ax[0].axis('off')
+
+    # Show segmentation mask
+ seg_mask = torch.argmax(seg_mask, dim=0).reshape(32, 32).numpy().astype(np.uint8)
+ # resize to 512x512 using nearest neighbor interpolation
+ seg_mask = Image.fromarray(seg_mask).resize((512, 512), Image.NEAREST)
+ seg_mask = np.array(seg_mask)
+ ax[1].imshow(seg_mask, cmap='tab20')
+ ax[1].set_title('Segmentation Mask')
+ ax[1].axis('off')
+
+    # Show sketch
+ sketch = sketch.reshape(32, 32).numpy().astype(np.uint8)
+ # resize to 512x512 using nearest neighbor interpolation
+ sketch = Image.fromarray(sketch).resize((512, 512), Image.NEAREST)
+ sketch = np.array(sketch)
+ ax[2].imshow(sketch, cmap='gray')
+ ax[2].set_title('Sketch')
+ ax[2].axis('off')
+
+ # Add title with text
+ fig.suptitle(text, fontsize=16)
+ plt.tight_layout()
+ plt.savefig('celeba_sample.png')
+
+    # save the seg_mask of each test sample as "27000.png, 27001.png, ..., 29999.png" to "/mnt/chongqinggeminiceph1fs/geminicephfs/wx-mm-spr-xxxx/zouxuechao/Collaborative-Diffusion/evaluation/CollDiff/real_mask"
+ from tqdm import tqdm
+ for data in tqdm(test_dataset):
+ mask = torch.argmax(data["conditions"]['seg_mask'], dim=0).reshape(32, 32).numpy().astype(np.uint8)
+ mask = Image.fromarray(mask).resize((512, 512), Image.NEAREST)
+ mask.save(f"/mnt/chongqinggeminiceph1fs/geminicephfs/wx-mm-spr-xxxx/zouxuechao/Collaborative-Diffusion/evaluation/CollDiff/real_mask/{data['image_name']}.png")
\ No newline at end of file
diff --git a/src/data/components/hrcwhu.py b/src/data/components/hrcwhu.py
new file mode 100644
index 0000000000000000000000000000000000000000..f37d6bb085701221178907e84bec684177ea1fac
--- /dev/null
+++ b/src/data/components/hrcwhu.py
@@ -0,0 +1,137 @@
+import os
+
+import albumentations
+import numpy as np
+from PIL import Image
+from torch.utils.data import Dataset
+
+
+class HRCWHU(Dataset):
+ METAINFO = dict(
+ classes=('clear sky', 'cloud'),
+ palette=((128, 192, 128), (255, 255, 255)),
+ img_size=(3, 256, 256), # C, H, W
+ ann_size=(256, 256), # C, H, W
+ train_size=120,
+ test_size=30,
+ )
+
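+    # expected directory layout under `root` (inferred from `load_data` below):
+    #   <root>/train.txt, <root>/test.txt          one image filename per line
+    #   <root>/img_dir/{train,test}/<image_file>   input images
+    #   <root>/ann_dir/{train,test}/<image_file>   annotation masks (same filenames)
+    #   each filename is expected to start with the land-cover type followed by '_'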
+ def __init__(self, root, phase, all_transform: albumentations.Compose = None,
+ img_transform: albumentations.Compose = None,
+ ann_transform: albumentations.Compose = None, seed: int = 42):
+ self.root = root
+ self.phase = phase
+ self.all_transform = all_transform
+ self.img_transform = img_transform
+ self.ann_transform = ann_transform
+ self.seed = seed
+ self.data = self.load_data()
+
+
+ def load_data(self):
+ data_list = []
+ split = 'train' if self.phase == 'train' else 'test'
+ split_file = os.path.join(self.root, f'{split}.txt')
+ with open(split_file, 'r') as f:
+ for line in f:
+ image_file = line.strip()
+ img_path = os.path.join(self.root, 'img_dir', split, image_file)
+ ann_path = os.path.join(self.root, 'ann_dir', split, image_file)
+ lac_type = image_file.split('_')[0]
+ data_list.append((img_path, ann_path, lac_type))
+ return data_list
+
+ def __len__(self):
+ return len(self.data)
+
+ def __getitem__(self, idx):
+ img_path, ann_path, lac_type = self.data[idx]
+ img = Image.open(img_path)
+ ann = Image.open(ann_path)
+
+ img = np.array(img)
+ ann = np.array(ann)
+
+ if self.all_transform:
+ albumention = self.all_transform(image=img, mask=ann)
+ img = albumention['image']
+ ann = albumention['mask']
+
+ if self.img_transform:
+ img = self.img_transform(image=img)['image']
+
+ if self.ann_transform:
+            ann = self.ann_transform(image=ann)['image']
+
+ # if self.img_transform is not None:
+ # img = self.img_transform(img)
+ # if self.ann_transform is not None:
+ # ann = self.ann_transform(ann)
+ # if self.all_transform is not None:
+        #     # apply the same random transform to both img and ann
+ # # seed_everything(self.seed, workers=True)
+ # # random.seed(self.seed)
+ # # img= self.all_transform(img)
+ # # seed_everything(self.seed, workers=True)
+ # # random.seed(self.seed)
+ # # ann= self.all_transform(ann)
+ # merge = torch.cat((img, ann), dim=0)
+ # merge = self.all_transform(merge)
+ # img = merge[:-1]
+ # ann = merge[-1]
+
+ return {
+ 'img': img,
+ 'ann': np.int64(ann),
+ 'img_path': img_path,
+ 'ann_path': ann_path,
+ 'lac_type': lac_type,
+ }
+
+
+if __name__ == '__main__':
+    import torch
+    from albumentations.pytorch import ToTensorV2
+
+    # the dataset applies albumentations-style transforms (called with keyword arguments
+    # and returning dicts), so the test pipelines are built with albumentations
+    all_transform = albumentations.Compose([albumentations.RandomCrop(height=256, width=256)])
+
+    img_transform = albumentations.Compose([ToTensorV2()])
+
+    ann_transform = None
+
+    train_dataset = HRCWHU(root='data/hrcwhu', phase='train', all_transform=all_transform, img_transform=img_transform,
+                           ann_transform=ann_transform)
+    test_dataset = HRCWHU(root='data/hrcwhu', phase='test', all_transform=all_transform, img_transform=img_transform,
+                          ann_transform=ann_transform)
+
+ assert len(train_dataset) == train_dataset.METAINFO['train_size']
+ assert len(test_dataset) == test_dataset.METAINFO['test_size']
+
+ train_sample = train_dataset[0]
+ test_sample = test_dataset[0]
+
+ assert train_sample['img'].shape == test_sample['img'].shape == train_dataset.METAINFO['img_size']
+ assert train_sample['ann'].shape == test_sample['ann'].shape == train_dataset.METAINFO['ann_size']
+
+ import matplotlib.pyplot as plt
+
+ fig, axs = plt.subplots(1, 2, figsize=(10, 5))
+ for train_sample in train_dataset:
+ axs[0].imshow(train_sample['img'].permute(1, 2, 0))
+ axs[0].set_title('Image')
+ axs[1].imshow(torch.tensor(train_dataset.METAINFO['palette'])[train_sample['ann']])
+ axs[1].set_title('Annotation')
+ plt.suptitle(f'Land Cover Type: {train_sample["lac_type"].capitalize()}', y=0.8)
+ plt.tight_layout()
+ plt.savefig('HRCWHU_sample.png', bbox_inches="tight")
+ # break
diff --git a/src/data/components/mnist.py b/src/data/components/mnist.py
new file mode 100644
index 0000000000000000000000000000000000000000..f6a1ad6ceb4ea1a5a73060bab69ba2a2752cac4a
--- /dev/null
+++ b/src/data/components/mnist.py
@@ -0,0 +1,51 @@
+import h5py
+import matplotlib.pyplot as plt
+import numpy as np
+from torch.utils.data import Dataset
+from torchvision.transforms import ToTensor
+
+
+class MNIST(Dataset):
+ def __init__(self, h5_file, transform=ToTensor()):
+ self.h5_file = h5_file
+ self.transform = transform
+        # read the HDF5 file
+ with h5py.File(self.h5_file, 'r') as file:
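+            # the file is expected to hold one dataset per digit, keyed "0".."9",
+            # each containing all images of that digit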
+ self.data = []
+ self.labels = []
+ for i in range(10):
+ images = file[str(i)][()]
+ for img in images:
+ self.data.append(img)
+ self.labels.append(i)
+ self.data = np.array(self.data)
+ self.labels = np.array(self.labels)
+
+ def __len__(self):
+ return len(self.data)
+
+ def __getitem__(self, idx):
+ image = self.data[idx]
+ label = self.labels[idx]
+
+ if self.transform:
+ image = self.transform(image)
+
+ return image, label
+
+
+if __name__ == '__main__':
+ mnist_h5_dataset = MNIST('data/mnist.h5')
+
+ assert len(mnist_h5_dataset) == 70000
+
+ # Display the first 10 images of each digit, along with their labels, in a 10x10 grid
+ fig, axs = plt.subplots(10, 10, figsize=(10, 10))
+ for i in range(10):
+ images = mnist_h5_dataset.data[mnist_h5_dataset.labels == i]
+ for j in range(10):
+ axs[i, j].imshow(images[j], cmap='gray')
+ axs[i, j].axis('off')
+ axs[i, j].set_title(i)
+ plt.tight_layout()
+ plt.savefig("mnist_h5_dataset.png")
diff --git a/src/data/hrcwhu_datamodule.py b/src/data/hrcwhu_datamodule.py
new file mode 100644
index 0000000000000000000000000000000000000000..9626f557e55cb719d8fbc4c3733b9edffa8bcf1f
--- /dev/null
+++ b/src/data/hrcwhu_datamodule.py
@@ -0,0 +1,164 @@
+from typing import Any, Dict, Optional
+
+from lightning import LightningDataModule
+from torch.utils.data import DataLoader, Dataset
+
+from src.data.components.hrcwhu import HRCWHU
+
+
+class HRCWHUDataModule(LightningDataModule):
+ def __init__(
+ self,
+ root: str,
+        train_pipeline: Dict[str, Any],
+        val_pipeline: Dict[str, Any],
+        test_pipeline: Dict[str, Any],
+        seed: int = 42,
+ batch_size: int = 1,
+ num_workers: int = 0,
+ pin_memory: bool = False,
+ persistent_workers: bool = False,
+ ) -> None:
+ super().__init__()
+
+ # this line allows to access init params with 'self.hparams' attribute
+ # also ensures init params will be stored in ckpt
+ self.save_hyperparameters(logger=False)
+
+ self.train_dataset: Optional[Dataset] = None
+ self.val_dataset: Optional[Dataset] = None
+ self.test_dataset: Optional[Dataset] = None
+
+ self.batch_size_per_device = batch_size
+
+ @property
+ def num_classes(self) -> int:
+ return len(HRCWHU.METAINFO["classes"])
+
+ def prepare_data(self) -> None:
+ """Download data if needed. Lightning ensures that `self.prepare_data()` is called only
+ within a single process on CPU, so you can safely add your downloading logic within. In
+ case of multi-node training, the execution of this hook depends upon
+ `self.prepare_data_per_node()`.
+
+ Do not use it to assign state (self.x = y).
+ """
+ # train
+ HRCWHU(
+ root=self.hparams.root,
+ phase="train",
+ **self.hparams.train_pipeline,
+ seed=self.hparams.seed,
+ )
+
+ # val or test
+ HRCWHU(
+ root=self.hparams.root,
+ phase="test",
+ **self.hparams.test_pipeline,
+ seed=self.hparams.seed,
+ )
+
+ def setup(self, stage: Optional[str] = None) -> None:
+ """Load data. Set variables: `self.data_train`, `self.data_val`, `self.data_test`.
+
+ This method is called by Lightning before `trainer.fit()`, `trainer.validate()`, `trainer.test()`, and
+ `trainer.predict()`, so be careful not to execute things like random split twice! Also, it is called after
+ `self.prepare_data()` and there is a barrier in between which ensures that all the processes proceed to
+ `self.setup()` once the data is prepared and available for use.
+
+ :param stage: The stage to setup. Either `"fit"`, `"validate"`, `"test"`, or `"predict"`. Defaults to ``None``.
+ """
+ # Divide batch size by the number of devices.
+ if self.trainer is not None:
+ if self.hparams.batch_size % self.trainer.world_size != 0:
+ raise RuntimeError(
+ f"Batch size ({self.hparams.batch_size}) is not divisible by the number of devices ({self.trainer.world_size})."
+ )
+ self.batch_size_per_device = self.hparams.batch_size // self.trainer.world_size
+
+ # load and split datasets only if not loaded already
+ if not self.train_dataset and not self.val_dataset and not self.test_dataset:
+ self.train_dataset = HRCWHU(
+ root=self.hparams.root,
+ phase="train",
+ **self.hparams.train_pipeline,
+ seed=self.hparams.seed,
+ )
+
+ self.val_dataset = self.test_dataset = HRCWHU(
+ root=self.hparams.root,
+ phase="test",
+ **self.hparams.test_pipeline,
+ seed=self.hparams.seed,
+ )
+
+ def train_dataloader(self) -> DataLoader[Any]:
+ """Create and return the train dataloader.
+
+ :return: The train dataloader.
+ """
+ return DataLoader(
+ dataset=self.train_dataset,
+ batch_size=self.batch_size_per_device,
+ num_workers=self.hparams.num_workers,
+ pin_memory=self.hparams.pin_memory,
+ persistent_workers=self.hparams.persistent_workers,
+ shuffle=True,
+ )
+
+ def val_dataloader(self) -> DataLoader[Any]:
+ """Create and return the validation dataloader.
+
+ :return: The validation dataloader.
+ """
+ return DataLoader(
+ dataset=self.val_dataset,
+ batch_size=self.batch_size_per_device,
+ num_workers=self.hparams.num_workers,
+ pin_memory=self.hparams.pin_memory,
+ persistent_workers=self.hparams.persistent_workers,
+ shuffle=False,
+ )
+
+ def test_dataloader(self) -> DataLoader[Any]:
+ """Create and return the test dataloader.
+
+ :return: The test dataloader.
+ """
+ return DataLoader(
+ dataset=self.test_dataset,
+ batch_size=self.batch_size_per_device,
+ num_workers=self.hparams.num_workers,
+ pin_memory=self.hparams.pin_memory,
+ persistent_workers=self.hparams.persistent_workers,
+ shuffle=False,
+ )
+
+ def teardown(self, stage: Optional[str] = None) -> None:
+ """Lightning hook for cleaning up after `trainer.fit()`, `trainer.validate()`,
+ `trainer.test()`, and `trainer.predict()`.
+
+ :param stage: The stage being torn down. Either `"fit"`, `"validate"`, `"test"`, or `"predict"`.
+ Defaults to ``None``.
+ """
+ pass
+
+ def state_dict(self) -> Dict[Any, Any]:
+ """Called when saving a checkpoint. Implement to generate and save the datamodule state.
+
+ :return: A dictionary containing the datamodule state that you want to save.
+ """
+ return {}
+
+ def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
+ """Called when loading a checkpoint. Implement to reload datamodule state given datamodule
+ `state_dict()`.
+
+ :param state_dict: The datamodule state returned by `self.state_dict()`.
+ """
+ pass
+
+
+if __name__ == "__main__":
+ _ = HRCWHUDataModule()
diff --git a/src/data/mnist_datamodule.py b/src/data/mnist_datamodule.py
new file mode 100644
index 0000000000000000000000000000000000000000..312dc0178322ed8908f9420396aae7e5e8333a2c
--- /dev/null
+++ b/src/data/mnist_datamodule.py
@@ -0,0 +1,210 @@
+from typing import Any, Dict, Optional, Tuple
+
+import torch
+from lightning import LightningDataModule
+from torch.utils.data import DataLoader, Dataset, random_split
+from torchvision.transforms import transforms
+
+from src.data.components.mnist import MNIST
+
+
+class MNISTDataModule(LightningDataModule):
+ """`LightningDataModule` for the MNIST dataset.
+
+ The MNIST database of handwritten digits has a training set of 60,000 examples, and a test set of 10,000 examples.
+ It is a subset of a larger set available from NIST. The digits have been size-normalized and centered in a
+ fixed-size image. The original black and white images from NIST were size normalized to fit in a 20x20 pixel box
+ while preserving their aspect ratio. The resulting images contain grey levels as a result of the anti-aliasing
+    technique used by the normalization algorithm. The images were centered in a 28x28 image by computing the center of
+ mass of the pixels, and translating the image so as to position this point at the center of the 28x28 field.
+
+ A `LightningDataModule` implements 7 key methods:
+
+ ```python
+ def prepare_data(self):
+ # Things to do on 1 GPU/TPU (not on every GPU/TPU in DDP).
+ # Download data, pre-process, split, save to disk, etc...
+
+ def setup(self, stage):
+ # Things to do on every process in DDP.
+ # Load data, set variables, etc...
+
+ def train_dataloader(self):
+ # return train dataloader
+
+ def val_dataloader(self):
+ # return validation dataloader
+
+ def test_dataloader(self):
+ # return test dataloader
+
+ def predict_dataloader(self):
+ # return predict dataloader
+
+ def teardown(self, stage):
+ # Called on every process in DDP.
+ # Clean up after fit or test.
+ ```
+
+ This allows you to share a full dataset without explaining how to download,
+ split, transform and process the data.
+
+ Read the docs:
+ https://lightning.ai/docs/pytorch/latest/data/datamodule.html
+ """
+
+ def __init__(
+ self,
+ data_dir: str = "data/",
+ train_val_test_split: Tuple[int, int, int] = (55_000, 5_000, 10_000),
+ batch_size: int = 64,
+ num_workers: int = 0,
+ pin_memory: bool = False,
+ persistent_workers: bool = False,
+ ) -> None:
+ """Initialize a `MNISTDataModule`.
+
+ :param data_dir: The data directory. Defaults to `"data/"`.
+ :param train_val_test_split: The train, validation and test split. Defaults to `(55_000, 5_000, 10_000)`.
+ :param batch_size: The batch size. Defaults to `64`.
+ :param num_workers: The number of workers. Defaults to `0`.
+ :param pin_memory: Whether to pin memory. Defaults to `False`.
+ :param persistent_workers: Whether to keep workers alive between data loading. Defaults to `False`.
+ """
+ super().__init__()
+
+ # this line allows to access init params with 'self.hparams' attribute
+ # also ensures init params will be stored in ckpt
+ self.save_hyperparameters(logger=False)
+
+ # data transformations
+ self.transforms = transforms.Compose(
+ [transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))]
+ )
+
+ self.data_train: Optional[Dataset] = None
+ self.data_val: Optional[Dataset] = None
+ self.data_test: Optional[Dataset] = None
+
+ self.batch_size_per_device = batch_size
+
+ @property
+ def num_classes(self) -> int:
+ """Get the number of classes.
+
+ :return: The number of MNIST classes (10).
+ """
+ return 10
+
+ def prepare_data(self) -> None:
+ """Download data if needed. Lightning ensures that `self.prepare_data()` is called only
+ within a single process on CPU, so you can safely add your downloading logic within. In
+ case of multi-node training, the execution of this hook depends upon
+ `self.prepare_data_per_node()`.
+
+ Do not use it to assign state (self.x = y).
+ """
+ MNIST(
+ h5_file=f"{self.hparams.data_dir}/mnist.h5",
+ transform=self.transforms,
+ )
+
+ def setup(self, stage: Optional[str] = None) -> None:
+ """Load data. Set variables: `self.data_train`, `self.data_val`, `self.data_test`.
+
+ This method is called by Lightning before `trainer.fit()`, `trainer.validate()`, `trainer.test()`, and
+ `trainer.predict()`, so be careful not to execute things like random split twice! Also, it is called after
+ `self.prepare_data()` and there is a barrier in between which ensures that all the processes proceed to
+ `self.setup()` once the data is prepared and available for use.
+
+ :param stage: The stage to setup. Either `"fit"`, `"validate"`, `"test"`, or `"predict"`. Defaults to ``None``.
+ """
+ # Divide batch size by the number of devices.
+ if self.trainer is not None:
+ if self.hparams.batch_size % self.trainer.world_size != 0:
+ raise RuntimeError(
+ f"Batch size ({self.hparams.batch_size}) is not divisible by the number of devices ({self.trainer.world_size})."
+ )
+ self.batch_size_per_device = self.hparams.batch_size // self.trainer.world_size
+
+ # load and split datasets only if not loaded already
+ if not self.data_train and not self.data_val and not self.data_test:
+ dataset = MNIST(
+ h5_file=f"{self.hparams.data_dir}/mnist.h5",
+ transform=self.transforms,
+ )
+ self.data_train, self.data_val, self.data_test = random_split(
+ dataset=dataset,
+ lengths=self.hparams.train_val_test_split,
+ generator=torch.Generator().manual_seed(42),
+ )
+
+ def train_dataloader(self) -> DataLoader[Any]:
+ """Create and return the train dataloader.
+
+ :return: The train dataloader.
+ """
+ return DataLoader(
+ dataset=self.data_train,
+ batch_size=self.batch_size_per_device,
+ num_workers=self.hparams.num_workers,
+ pin_memory=self.hparams.pin_memory,
+ persistent_workers=self.hparams.persistent_workers,
+ shuffle=True,
+ )
+
+ def val_dataloader(self) -> DataLoader[Any]:
+ """Create and return the validation dataloader.
+
+ :return: The validation dataloader.
+ """
+ return DataLoader(
+ dataset=self.data_val,
+ batch_size=self.batch_size_per_device,
+ num_workers=self.hparams.num_workers,
+ pin_memory=self.hparams.pin_memory,
+ persistent_workers=self.hparams.persistent_workers,
+ shuffle=False,
+ )
+
+ def test_dataloader(self) -> DataLoader[Any]:
+ """Create and return the test dataloader.
+
+ :return: The test dataloader.
+ """
+ return DataLoader(
+ dataset=self.data_test,
+ batch_size=self.batch_size_per_device,
+ num_workers=self.hparams.num_workers,
+ pin_memory=self.hparams.pin_memory,
+ persistent_workers=self.hparams.persistent_workers,
+ shuffle=False,
+ )
+
+ def teardown(self, stage: Optional[str] = None) -> None:
+ """Lightning hook for cleaning up after `trainer.fit()`, `trainer.validate()`,
+ `trainer.test()`, and `trainer.predict()`.
+
+ :param stage: The stage being torn down. Either `"fit"`, `"validate"`, `"test"`, or `"predict"`.
+ Defaults to ``None``.
+ """
+ pass
+
+ def state_dict(self) -> Dict[Any, Any]:
+ """Called when saving a checkpoint. Implement to generate and save the datamodule state.
+
+ :return: A dictionary containing the datamodule state that you want to save.
+ """
+ return {}
+
+ def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
+ """Called when loading a checkpoint. Implement to reload datamodule state given datamodule
+ `state_dict()`.
+
+ :param state_dict: The datamodule state returned by `self.state_dict()`.
+ """
+ pass
+
+
+if __name__ == "__main__":
+ _ = MNISTDataModule()
diff --git a/src/eval.py b/src/eval.py
new file mode 100644
index 0000000000000000000000000000000000000000..b70faae8b59c2d508a070cef7fa85ed39be0a3c1
--- /dev/null
+++ b/src/eval.py
@@ -0,0 +1,99 @@
+from typing import Any, Dict, List, Tuple
+
+import hydra
+import rootutils
+from lightning import LightningDataModule, LightningModule, Trainer
+from lightning.pytorch.loggers import Logger
+from omegaconf import DictConfig
+
+rootutils.setup_root(__file__, indicator=".project-root", pythonpath=True)
+# ------------------------------------------------------------------------------------ #
+# the setup_root above is equivalent to:
+# - adding project root dir to PYTHONPATH
+# (so you don't need to force user to install project as a package)
+# (necessary before importing any local modules e.g. `from src import utils`)
+# - setting up PROJECT_ROOT environment variable
+# (which is used as a base for paths in "configs/paths/default.yaml")
+# (this way all filepaths are the same no matter where you run the code)
+# - loading environment variables from ".env" in root dir
+#
+# you can remove it if you:
+# 1. either install project as a package or move entry files to project root dir
+# 2. set `root_dir` to "." in "configs/paths/default.yaml"
+#
+# more info: https://github.com/ashleve/rootutils
+# ------------------------------------------------------------------------------------ #
+
+from src.utils import (
+ RankedLogger,
+ extras,
+ instantiate_loggers,
+ log_hyperparameters,
+ task_wrapper,
+)
+
+log = RankedLogger(__name__, rank_zero_only=True)
+
+
+@task_wrapper
+def evaluate(cfg: DictConfig) -> Tuple[Dict[str, Any], Dict[str, Any]]:
+ """Evaluates given checkpoint on a datamodule testset.
+
+ This method is wrapped in optional @task_wrapper decorator, that controls the behavior during
+ failure. Useful for multiruns, saving info about the crash, etc.
+
+ :param cfg: DictConfig configuration composed by Hydra.
+ :return: Tuple[dict, dict] with metrics and dict with all instantiated objects.
+ """
+ assert cfg.ckpt_path
+
+ log.info(f"Instantiating datamodule <{cfg.data._target_}>")
+ datamodule: LightningDataModule = hydra.utils.instantiate(cfg.data)
+
+ log.info(f"Instantiating model <{cfg.model._target_}>")
+ model: LightningModule = hydra.utils.instantiate(cfg.model)
+
+ log.info("Instantiating loggers...")
+ logger: List[Logger] = instantiate_loggers(cfg.get("logger"))
+
+ log.info(f"Instantiating trainer <{cfg.trainer._target_}>")
+ trainer: Trainer = hydra.utils.instantiate(cfg.trainer, logger=logger)
+
+ object_dict = {
+ "cfg": cfg,
+ "datamodule": datamodule,
+ "model": model,
+ "logger": logger,
+ "trainer": trainer,
+ }
+
+ if logger:
+ log.info("Logging hyperparameters!")
+ log_hyperparameters(object_dict)
+
+ log.info("Starting testing!")
+ trainer.test(model=model, datamodule=datamodule, ckpt_path=cfg.ckpt_path)
+
+ # for predictions use trainer.predict(...)
+ # predictions = trainer.predict(model=model, dataloaders=dataloaders, ckpt_path=cfg.ckpt_path)
+
+ metric_dict = trainer.callback_metrics
+
+ return metric_dict, object_dict
+
+
+@hydra.main(version_base="1.3", config_path="../configs", config_name="eval.yaml")
+def main(cfg: DictConfig) -> None:
+ """Main entry point for evaluation.
+
+ :param cfg: DictConfig configuration composed by Hydra.
+ """
+ # apply extra utilities
+ # (e.g. ask for tags if none are provided in cfg, print cfg tree, etc.)
+ extras(cfg)
+
+ evaluate(cfg)
+
+
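+# example invocation (checkpoint path is a placeholder):
+#   python src/eval.py ckpt_path=logs/train/runs/<run_dir>/checkpoints/last.ckpt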
+if __name__ == "__main__":
+ main()
diff --git a/src/models/base_module.py b/src/models/base_module.py
new file mode 100644
index 0000000000000000000000000000000000000000..2049662e86ba75db570fa2fca937d0f5da098968
--- /dev/null
+++ b/src/models/base_module.py
@@ -0,0 +1,280 @@
+from typing import Any, Dict, Tuple
+
+import torch
+from lightning import LightningModule
+from torchmetrics import MaxMetric, MeanMetric
+from torchmetrics.classification import Accuracy, F1Score, Precision, Recall
+from torchmetrics.segmentation import MeanIoU, GeneralizedDiceScore
+
+
+class BaseLitModule(LightningModule):
+ """Example of a `LightningModule` for MNIST classification.
+
+ A `LightningModule` implements 8 key methods:
+
+ ```python
+ def __init__(self):
+ # Define initialization code here.
+
+ def setup(self, stage):
+ # Things to setup before each stage, 'fit', 'validate', 'test', 'predict'.
+ # This hook is called on every process when using DDP.
+
+ def training_step(self, batch, batch_idx):
+ # The complete training step.
+
+ def validation_step(self, batch, batch_idx):
+ # The complete validation step.
+
+ def test_step(self, batch, batch_idx):
+ # The complete test step.
+
+ def predict_step(self, batch, batch_idx):
+ # The complete predict step.
+
+ def configure_optimizers(self):
+ # Define and configure optimizers and LR schedulers.
+ ```
+
+ Docs:
+ https://lightning.ai/docs/pytorch/latest/common/lightning_module.html
+ """
+
+ def __init__(
+ self,
+ net: torch.nn.Module,
+ num_classes: int,
+ criterion: torch.nn.Module,
+ optimizer: torch.optim.Optimizer,
+ scheduler: torch.optim.lr_scheduler,
+ compile: bool = False,
+ ) -> None:
+ super().__init__()
+
+ # this line allows to access init params with 'self.hparams' attribute
+ # also ensures init params will be stored in ckpt
+ self.save_hyperparameters(logger=False, ignore=['net'])
+
+ self.net = net
+
+ # metric objects for calculating and averaging accuracy across batches
+ task = "binary" if self.hparams.num_classes==2 else "multiclass"
+
+ self.train_accuracy = Accuracy(task=task, num_classes=num_classes)
+ self.train_precision = Precision(task=task, num_classes=num_classes)
+ self.train_recall = Recall(task=task, num_classes=num_classes)
+ self.train_f1score = F1Score(task=task, num_classes=num_classes)
+ self.train_miou = MeanIoU(num_classes=num_classes)
+ self.train_dice = GeneralizedDiceScore(num_classes=num_classes)
+
+ self.val_accuracy = Accuracy(task=task, num_classes=num_classes)
+ self.val_precision = Precision(task=task, num_classes=num_classes)
+ self.val_recall = Recall(task=task, num_classes=num_classes)
+ self.val_f1score = F1Score(task=task, num_classes=num_classes)
+ self.val_miou = MeanIoU(num_classes=num_classes)
+ self.val_dice = GeneralizedDiceScore(num_classes=num_classes)
+
+ self.test_accuracy = Accuracy(task=task, num_classes=num_classes)
+ self.test_precision = Precision(task=task, num_classes=num_classes)
+ self.test_recall = Recall(task=task, num_classes=num_classes)
+ self.test_f1score = F1Score(task=task, num_classes=num_classes)
+ self.test_miou = MeanIoU(num_classes=num_classes)
+ self.test_dice = GeneralizedDiceScore(num_classes=num_classes)
+
+ # for averaging loss across batches
+ self.train_loss = MeanMetric()
+ self.val_loss = MeanMetric()
+ self.test_loss = MeanMetric()
+
+ # for tracking best so far validation accuracy
+ self.val_miou_best = MaxMetric()
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ """Perform a forward pass through the model `self.net`.
+
+ :param x: A tensor of images.
+ :return: A tensor of logits.
+ """
+ return self.net(x)
+
+ def on_train_start(self) -> None:
+ """Lightning hook that is called when training begins."""
+ # by default lightning executes validation step sanity checks before training starts,
+ # so it's worth to make sure validation metrics don't store results from these checks
+ self.val_loss.reset()
+
+ self.val_accuracy.reset()
+ self.val_precision.reset()
+ self.val_recall.reset()
+ self.val_f1score.reset()
+ self.val_miou.reset()
+ self.val_dice.reset()
+
+ self.val_miou_best.reset()
+
+ def model_step(
+ self, batch: Tuple[torch.Tensor, torch.Tensor]
+ ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+ """Perform a single model step on a batch of data.
+
+ :param batch: A batch of data (a tuple) containing the input tensor of images and target labels.
+
+ :return: A tuple containing (in order):
+ - A tensor of losses.
+ - A tensor of predictions.
+ - A tensor of target labels.
+ """
+ x, y = batch["img"], batch["ann"]
+ logits = self.forward(x)
+ loss = self.hparams.criterion(logits, y)
+ preds = torch.argmax(logits, dim=1)
+ return loss, preds, y
+
+ def training_step(
+ self, batch: Tuple[torch.Tensor, torch.Tensor], batch_idx: int
+ ) -> torch.Tensor:
+ """Perform a single training step on a batch of data from the training set.
+
+ :param batch: A batch of data (a tuple) containing the input tensor of images and target
+ labels.
+ :param batch_idx: The index of the current batch.
+ :return: A tensor of losses between model predictions and targets.
+ """
+ loss, preds, targets = self.model_step(batch)
+
+ # print(preds.shape) # (8, 256, 256)
+ # print(targets.shape) # (8, 256, 256)
+
+ # update and log metrics
+ self.train_loss(loss)
+
+ self.train_accuracy(preds, targets)
+ self.train_precision(preds, targets)
+ self.train_recall(preds, targets)
+ self.train_f1score(preds, targets)
+ self.train_miou(preds, targets)
+ self.train_dice(preds, targets)
+
+ self.log("train/loss", self.train_loss, on_step=False, on_epoch=True, prog_bar=True)
+
+ self.log("train/accuracy", self.train_accuracy, on_step=False, on_epoch=True, prog_bar=True)
+ self.log("train/precision", self.train_precision, on_step=False, on_epoch=True, prog_bar=True)
+ self.log("train/recall", self.train_recall, on_step=False, on_epoch=True, prog_bar=True)
+ self.log("train/f1score", self.train_f1score, on_step=False, on_epoch=True, prog_bar=True)
+ self.log("train/miou", self.train_miou, on_step=False, on_epoch=True, prog_bar=True)
+ self.log("train/dice", self.train_dice, on_step=False, on_epoch=True, prog_bar=True)
+
+ # return loss or backpropagation will fail
+ return loss
+
+ def on_train_epoch_end(self) -> None:
+ "Lightning hook that is called when a training epoch ends."
+ pass
+
+ def validation_step(self, batch: Tuple[torch.Tensor, torch.Tensor], batch_idx: int) -> None:
+ """Perform a single validation step on a batch of data from the validation set.
+
+ :param batch: A batch of data (a tuple) containing the input tensor of images and target
+ labels.
+ :param batch_idx: The index of the current batch.
+ """
+ loss, preds, targets = self.model_step(batch)
+
+ # update and log metrics
+ self.val_loss(loss)
+
+ self.val_accuracy(preds, targets)
+ self.val_precision(preds, targets)
+ self.val_recall(preds, targets)
+ self.val_f1score(preds, targets)
+ self.val_miou(preds, targets)
+ self.val_dice(preds, targets)
+
+ self.log("val/loss", self.val_loss, on_step=False, on_epoch=True, prog_bar=True)
+
+ self.log("val/accuracy", self.val_accuracy, on_step=False, on_epoch=True, prog_bar=True)
+ self.log("val/precision", self.val_precision, on_step=False, on_epoch=True, prog_bar=True)
+ self.log("val/recall", self.val_recall, on_step=False, on_epoch=True, prog_bar=True)
+ self.log("val/f1score", self.val_f1score, on_step=False, on_epoch=True, prog_bar=True)
+ self.log("val/miou", self.val_miou, on_step=False, on_epoch=True, prog_bar=True)
+ self.log("val/dice", self.val_dice, on_step=False, on_epoch=True, prog_bar=True)
+
+ def on_validation_epoch_end(self) -> None:
+ "Lightning hook that is called when a validation epoch ends."
+ miou = self.val_miou.compute() # get current val acc
+ self.val_miou_best(miou) # update best so far val acc
+ # log `val_acc_best` as a value through `.compute()` method, instead of as a metric object
+ # otherwise metric would be reset by lightning after each epoch
+ self.log("val/miou_best", self.val_miou_best.compute(), sync_dist=True, prog_bar=True)
+
+ def test_step(self, batch: Tuple[torch.Tensor, torch.Tensor], batch_idx: int) -> None:
+ """Perform a single test step on a batch of data from the test set.
+
+ :param batch: A batch of data (a tuple) containing the input tensor of images and target
+ labels.
+ :param batch_idx: The index of the current batch.
+ """
+ loss, preds, targets = self.model_step(batch)
+
+ # update and log metrics
+ self.test_loss(loss)
+
+ # update and log metrics
+ self.test_accuracy(preds, targets)
+ self.test_precision(preds, targets)
+ self.test_recall(preds, targets)
+ self.test_f1score(preds, targets)
+ self.test_miou(preds, targets)
+ self.test_dice(preds, targets)
+
+ self.log("test/loss", self.test_loss, on_step=False, on_epoch=True, prog_bar=True)
+
+ self.log("test/accuracy", self.test_accuracy, on_step=False, on_epoch=True, prog_bar=True)
+ self.log("test/precision", self.test_precision, on_step=False, on_epoch=True, prog_bar=True)
+ self.log("test/recall", self.test_recall, on_step=False, on_epoch=True, prog_bar=True)
+ self.log("test/f1score", self.test_f1score, on_step=False, on_epoch=True, prog_bar=True)
+ self.log("test/miou", self.test_miou, on_step=False, on_epoch=True, prog_bar=True)
+ self.log("test/dice", self.test_dice, on_step=False, on_epoch=True, prog_bar=True)
+
+ def on_test_epoch_end(self) -> None:
+ """Lightning hook that is called when a test epoch ends."""
+ pass
+
+ def setup(self, stage: str) -> None:
+ """Lightning hook that is called at the beginning of fit (train + validate), validate,
+ test, or predict.
+
+ This is a good hook when you need to build models dynamically or adjust something about
+ them. This hook is called on every process when using DDP.
+
+ :param stage: Either `"fit"`, `"validate"`, `"test"`, or `"predict"`.
+ """
+ if self.hparams.compile and stage == "fit":
+ self.net = torch.compile(self.net)
+
+ def configure_optimizers(self) -> Dict[str, Any]:
+ """Choose what optimizers and learning-rate schedulers to use in your optimization.
+ Normally you'd need one. But in the case of GANs or similar you might have multiple.
+
+ Examples:
+ https://lightning.ai/docs/pytorch/latest/common/lightning_module.html#configure-optimizers
+
+ :return: A dict containing the configured optimizers and learning-rate schedulers to be used for training.
+ """
+ optimizer = self.hparams.optimizer(params=self.trainer.model.parameters())
+ if self.hparams.scheduler is not None:
+ scheduler = self.hparams.scheduler(optimizer=optimizer)
+ return {
+ "optimizer": optimizer,
+ "lr_scheduler": {
+ "scheduler": scheduler,
+ "monitor": "val/loss",
+ "interval": "epoch",
+ "frequency": 1,
+ },
+ }
+ return {"optimizer": optimizer}
+
+
+if __name__ == "__main__":
+    _ = BaseLitModule(None, None, None, None, None)
diff --git a/src/models/cdnetv2_module.py b/src/models/cdnetv2_module.py
new file mode 100644
index 0000000000000000000000000000000000000000..f4e690d6fe25c403f23a6f9740292351cbe3f136
--- /dev/null
+++ b/src/models/cdnetv2_module.py
@@ -0,0 +1,34 @@
+# -*- coding: utf-8 -*-
+# @Time : 2024/8/1 2:47 PM
+# @Author : xiaoshun
+# @Email : 3038523973@qq.com
+# @File : cdnetv2_module.py
+# @Software: PyCharm
+from typing import Tuple
+
+import torch
+
+import src.models.base_module
+
+
+class CDNetv2LitModule(src.models.base_module.BaseLitModule):
+    def __init__(self, **kwargs):
+ super().__init__(**kwargs)
+
+ def model_step(
+ self, batch: Tuple[torch.Tensor, torch.Tensor]
+ ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+ """Perform a single model step on a batch of data.
+
+ :param batch: A batch of data (a tuple) containing the input tensor of images and target labels.
+
+ :return: A tuple containing (in order):
+ - A tensor of losses.
+ - A tensor of predictions.
+ - A tensor of target labels.
+ """
+ x, y = batch["img"], batch["ann"]
+        logits, logits_aux = self.forward(x)
+        loss = self.hparams.criterion(logits, logits_aux, y)
+ preds = torch.argmax(logits, dim=1)
+ return loss, preds, y
diff --git a/src/models/components/cdnetv1.py b/src/models/components/cdnetv1.py
index 7c375b4165da2c3746e400582435aa79a941a2bb..d0f733cfa80bf51db36c05481929d134aad4ffe9 100644
--- a/src/models/components/cdnetv1.py
+++ b/src/models/components/cdnetv1.py
@@ -15,16 +15,9 @@ This is the implementation of CDnetV1 without multi-scale inputs. This implement
import torch
import torch.nn as nn
-import torch.optim as optim
import torch.nn.functional as F
-import torch.backends.cudnn as cudnn
-from torch.utils import data, model_zoo
-from torch.autograd import Variable
-import math
-import numpy as np
affine_par = True
-from torch.autograd import Function
def conv3x3(in_planes, out_planes, stride=1):
diff --git a/src/models/components/cdnetv2.py b/src/models/components/cdnetv2.py
index cb1e59eb7fc702f1ea523505c00f8b73a78eea2b..0f24da3526c26fcea20d05e441df31098d767ec4 100644
--- a/src/models/components/cdnetv2.py
+++ b/src/models/components/cdnetv2.py
@@ -13,18 +13,11 @@ This is the implementation of CDnetV2 without multi-scale inputs. This implement
# nn.GroupNorm
import torch
-from torch import nn
# import torch.nn as nn
-import torch.optim as optim
import torch.nn.functional as F
-import torch.backends.cudnn as cudnn
-from torch.utils import data, model_zoo
-from torch.autograd import Variable
-import math
-import numpy as np
+from torch import nn
affine_par = True
-from torch.autograd import Function
def conv3x3(in_planes, out_planes, stride=1):
diff --git a/src/models/components/cnn.py b/src/models/components/cnn.py
new file mode 100644
index 0000000000000000000000000000000000000000..73c103d61656f384b500ef5c8e6b603914210403
--- /dev/null
+++ b/src/models/components/cnn.py
@@ -0,0 +1,26 @@
+import torch
+from torch import nn
+
+
+class CNN(nn.Module):
+ def __init__(self, dim=32):
+ super(CNN, self).__init__()
+ self.conv1 = nn.Conv2d(1, dim, 5)
+ self.conv2 = nn.Conv2d(dim, dim * 2, 5)
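+ # with a 28x28 MNIST input: 5x5 conv -> 24x24, 2x2 pool -> 12x12, 5x5 conv -> 8x8, 2x2 pool -> 4x4,
+ # which gives the dim * 2 * 4 * 4 features consumed by fc1 below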
+ self.fc1 = nn.Linear(dim * 2 * 4 * 4, 10)
+
+ def forward(self, x):
+ x = torch.relu(self.conv1(x))
+ x = torch.max_pool2d(x, 2)
+ x = torch.relu(self.conv2(x))
+ x = torch.max_pool2d(x, 2)
+ x = x.view(-1, x.shape[1] * x.shape[2] * x.shape[3])
+ x = self.fc1(x)
+ return x
+
+
+if __name__ == "__main__":
+ input = torch.randn(2, 1, 28, 28)
+ model = CNN()
+ output = model(input)
+ assert output.shape == (2, 10)
\ No newline at end of file
diff --git a/src/models/components/dbnet.py b/src/models/components/dbnet.py
new file mode 100644
index 0000000000000000000000000000000000000000..3d55855f7d0353bccc537fa5c7c13f44dc2929b4
--- /dev/null
+++ b/src/models/components/dbnet.py
@@ -0,0 +1,680 @@
+# -*- coding: utf-8 -*-
+# @Time    : 2024/7/26 11:19 AM
+# @Author : xiaoshun
+# @Email : 3038523973@qq.com
+# @File : dbnet.py
+# @Software: PyCharm
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from einops import rearrange
+
+
+# from models.Transformer.ViT import truncated_normal_
+
+# Decoder refinement convolution module
+class SBR(nn.Module):
+ def __init__(self, in_ch):
+ super(SBR, self).__init__()
+ self.conv1x3 = nn.Sequential(
+ nn.Conv2d(in_ch, in_ch, kernel_size=(1, 3), stride=1, padding=(0, 1)),
+ nn.BatchNorm2d(in_ch),
+ nn.ReLU(True)
+ )
+ self.conv3x1 = nn.Sequential(
+ nn.Conv2d(in_ch, in_ch, kernel_size=(3, 1), stride=1, padding=(1, 0)),
+ nn.BatchNorm2d(in_ch),
+ nn.ReLU(True)
+ )
+
+ def forward(self, x):
+ out = self.conv3x1(self.conv1x3(x))  # apply the 1x3 convolution first, then the 3x1 convolution to its result
+ return out + x
+
+
+# Downsampling convolution module for stages 1, 2 and 3
+class c_stage123(nn.Module):
+ def __init__(self, in_chans, out_chans):
+ super().__init__()
+ self.stage123 = nn.Sequential(
+ nn.Conv2d(in_channels=in_chans, out_channels=out_chans, kernel_size=3, stride=2, padding=1),
+ nn.BatchNorm2d(out_chans),
+ nn.ReLU(),
+ nn.Conv2d(in_channels=out_chans, out_channels=out_chans, kernel_size=3, stride=1, padding=1),
+ nn.BatchNorm2d(out_chans),
+ nn.ReLU(),
+ )
+ self.conv1x1_123 = nn.Conv2d(in_channels=in_chans, out_channels=out_chans, kernel_size=1)
+ self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
+
+ def forward(self, x):
+ stage123 = self.stage123(x)  # 3x3 convs, 2x downsampling: 3*224*224 --> 64*112*112
+ max = self.maxpool(x)  # max pooling, 2x downsampling: 3*224*224 --> 3*112*112
+ max = self.conv1x1_123(max)  # 1x1 conv: 3*112*112 --> 64*112*112
+ stage123 = stage123 + max  # residual connection (with broadcasting)
+ return stage123
+
+
+# Downsampling convolution module for stages 4 and 5
+class c_stage45(nn.Module):
+ def __init__(self, in_chans, out_chans):
+ super().__init__()
+ self.stage45 = nn.Sequential(
+ nn.Conv2d(in_channels=in_chans, out_channels=out_chans, kernel_size=3, stride=2, padding=1),
+ nn.BatchNorm2d(out_chans),
+ nn.ReLU(),
+ nn.Conv2d(in_channels=out_chans, out_channels=out_chans, kernel_size=3, stride=1, padding=1),
+ nn.BatchNorm2d(out_chans),
+ nn.ReLU(),
+ nn.Conv2d(in_channels=out_chans, out_channels=out_chans, kernel_size=3, stride=1, padding=1),
+ nn.BatchNorm2d(out_chans),
+ nn.ReLU(),
+ )
+ self.conv1x1_45 = nn.Conv2d(in_channels=in_chans, out_channels=out_chans, kernel_size=1)
+ self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
+
+ def forward(self, x):
+ stage45 = self.stage45(x)  # 3x3 conv block, 2x downsampling
+ max = self.maxpool(x)  # max pooling, 2x downsampling
+ max = self.conv1x1_45(max)  # 1x1 conv to adjust the channel count
+ stage45 = stage45 + max  # residual connection
+ return stage45
+
+
+class Identity(nn.Module):  # identity mapping
+ def __init__(self):
+ super().__init__()
+
+ def forward(self, x):
+ return x
+
+
+# Lightweight convolution module
+class DepthwiseConv2d(nn.Module):  # used by the self-attention blocks
+ def __init__(self, in_chans, out_chans, kernel_size=1, stride=1, padding=0, dilation=1):
+ super().__init__()
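+ # depthwise-separable convolution: a per-channel (groups=in_chans) conv followed by a 1x1
+ # pointwise conv, giving a similar receptive field to a full conv with far fewer parameters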
+ # depthwise conv
+ self.depthwise = nn.Conv2d(
+ in_channels=in_chans,
+ out_channels=in_chans,
+ kernel_size=kernel_size,
+ stride=stride,
+ padding=padding,
+ dilation=dilation,  # dilation rate of the depthwise conv
+ groups=in_chans  # one group per input channel, i.e. a depthwise convolution
+ )
+ # batch norm
+ self.bn = nn.BatchNorm2d(num_features=in_chans)
+
+ # pointwise conv
+ self.pointwise = nn.Conv2d(
+ in_channels=in_chans,
+ out_channels=out_chans,
+ kernel_size=1
+ )
+
+ def forward(self, x):
+ x = self.depthwise(x)
+ x = self.bn(x)
+ x = self.pointwise(x)
+ return x
+
+
+# residual skip connection
+class Residual(nn.Module):
+ def __init__(self, fn):
+ super().__init__()
+ self.fn = fn
+
+ def forward(self, input, **kwargs):
+ x = self.fn(input, **kwargs)
+ return (x + input)
+
+
+# layer normalization applied before the wrapped layer
+class PreNorm(nn.Module):  # wraps a sub-layer with LayerNorm
+ def __init__(self, dim, fn):
+ super().__init__()
+ self.norm = nn.LayerNorm(dim)
+ self.fn = fn
+
+ def forward(self, input, **kwargs):
+ return self.fn(self.norm(input), **kwargs)
+
+
+# The FeedForward layer increases the expressive power of the representation
+class FeedForward(nn.Module):
+ def __init__(self, dim, hidden_dim, dropout=0.):
+ super().__init__()
+ self.net = nn.Sequential(
+ nn.Linear(in_features=dim, out_features=hidden_dim),
+ nn.GELU(),
+ nn.Dropout(dropout),
+ nn.Linear(in_features=hidden_dim, out_features=dim),
+ nn.Dropout(dropout)
+ )
+
+ def forward(self, input):
+ return self.net(input)
+
+
+class ConvAttention(nn.Module):
+ '''
+ Uses depthwise-separable Conv2d to produce q, k and v instead of the linear projection used in ViT.
+ '''
+
+ def __init__(self, dim, img_size, heads=8, dim_head=64, kernel_size=3, q_stride=1, k_stride=1, v_stride=1,
+ dropout=0., last_stage=False):
+ super().__init__()
+ self.last_stage = last_stage
+ self.img_size = img_size
+ inner_dim = dim_head * heads # 512
+ project_out = not (heads == 1 and dim_head == dim)
+
+ self.heads = heads
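+ # 1 / sqrt(dim_head) factor from scaled dot-product attention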
+ self.scale = dim_head ** (-0.5)
+
+ pad = (kernel_size - q_stride) // 2
+
+ self.to_q = DepthwiseConv2d(in_chans=dim, out_chans=inner_dim, kernel_size=kernel_size, stride=q_stride,
+ padding=pad) # 自注意力机制
+ self.to_k = DepthwiseConv2d(in_chans=dim, out_chans=inner_dim, kernel_size=kernel_size, stride=k_stride,
+ padding=pad)
+ self.to_v = DepthwiseConv2d(in_chans=dim, out_chans=inner_dim, kernel_size=kernel_size, stride=v_stride,
+ padding=pad)
+
+ self.to_out = nn.Sequential(
+ nn.Linear(
+ in_features=inner_dim,
+ out_features=dim
+ ),
+ nn.Dropout(dropout)
+ ) if project_out else Identity()
+
+ def forward(self, x):
+ b, n, c, h = *x.shape, self.heads  # unpack the (b, n, c) shape tuple and append the head count
+
+ # print(x.shape)
+ # print('+++++++++++++++++++++++++++++++++')
+
+ # this branch is not used here (DBNet never enables last_stage)
+ if self.last_stage:
+ cls_token = x[:, 0]
+ # print(cls_token.shape)
+ # print('+++++++++++++++++++++++++++++++++')
+ x = x[:, 1:]  # drop the first (cls) token from each sequence
+
+ cls_token = rearrange(torch.unsqueeze(cls_token, dim=1), 'b n (h d) -> b h n d', h=h)
+
+ # rearrange reorders tensor dimensions and can replace reshape, view, transpose and permute in PyTorch
+ x = rearrange(x, 'b (l w) n -> b n l w', l=self.img_size, w=self.img_size)  # [1, 3136, 64] --> 1*64*56*56
+ # batch_size, N (channels), h, w
+
+ q = self.to_q(x) # 1*64*56*56-->1*64*56*56
+ # print(q.shape)
+ # print('++++++++++++++')
+ q = rearrange(q, 'b (h d) l w -> b h (l w) d', h=h) # 1*64*56*56-->1*1*3136*64
+ # print(q.shape)
+ # print('=====================')
+ # batch_size,head,h*w,dim_head
+
+ k = self.to_k(x)  # same operations as for q
+ k = rearrange(k, 'b (h d) l w -> b h (l w) d', h=h)
+ # batch_size,head,h*w,dim_head
+
+ v = self.to_v(x)  # same operations as for q
+ # print(v.shape)
+ # print('[[[[[[[[[[[[[[[[[[[[[[[[[[[[')
+ v = rearrange(v, 'b (h d) l w -> b h (l w) d', h=h)
+ # print(v.shape)
+ # print(']]]]]]]]]]]]]]]]]]]]]]]]]]]')
+ # batch_size,head,h*w,dim_head
+
+ if self.last_stage:
+ # print(q.shape)
+ # print('================')
+ q = torch.cat([cls_token, q], dim=2)
+ # print(q.shape)
+ # print('++++++++++++++++++')
+ v = torch.cat([cls_token, v], dim=2)
+ k = torch.cat([cls_token, k], dim=2)
+
+ # calculate attention by matmul + scale
+ # permute: (batch_size, head, dim_head, h*w)
+ # print(k.shape)
+ # print('++++++++++++++++++++')
+ k = k.permute(0, 1, 3, 2) # 1*1*3136*64-->1*1*64*3136
+ # print(k.shape)
+ # print('====================')
+ attention = (q.matmul(k)) # 1*1*3136*3136
+ # print(attention.shape)
+ # print('--------------------')
+ attention = attention * self.scale  # scale the attention logits to keep their magnitude stable and avoid vanishing/exploding gradients
+ # print(attention.shape)
+ # print('####################')
+ # pass a softmax
+ attention = F.softmax(attention, dim=-1)
+ # print(attention.shape)
+ # print('********************')
+
+ # matmul v
+ # attention.matmul(v):(batch_size,head,h*w,dim_head)
+ # permute:(batch_size,h*w,head,dim_head)
+ out = (attention.matmul(v)).permute(0, 2, 1, 3).reshape(b, n, c)
+ # 1*3136*64: multiply the attention weights with the values, then reshape to (batch_size, sequence_length, value_dim)
+
+ # linear project
+ out = self.to_out(out)
+ return out
+
+
+# Reshape Layers
+class Rearrange(nn.Module):
+ def __init__(self, string, h, w):
+ super().__init__()
+ self.string = string
+ self.h = h
+ self.w = w
+
+ def forward(self, input):
+
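+ # only the two patterns used by DBNet below are handled; any other string would leave x undefined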
+ if self.string == 'b c h w -> b (h w) c':
+ N, C, H, W = input.shape
+ # print(input.shape)
+ x = torch.reshape(input, shape=(N, -1, self.h * self.w)).permute(0, 2, 1)
+ # print(x.shape)
+ # print('+++++++++++++++++++')
+ if self.string == 'b (h w) c -> b c h w':
+ N, _, C = input.shape
+ # print(input.shape)
+ x = torch.reshape(input, shape=(N, self.h, self.w, -1)).permute(0, 3, 1, 2)
+ # print(x.shape)
+ # print('=====================')
+ return x
+
+
+# Transformer layers
+class Transformer(nn.Module):
+ def __init__(self, dim, img_size, depth, heads, dim_head, mlp_dim, dropout=0., last_stage=False):
+ super().__init__()
+ self.layers = nn.ModuleList([  # registers the sub-modules so their parameters are tracked
+ nn.ModuleList([
+ PreNorm(dim=dim, fn=ConvAttention(dim, img_size, heads=heads, dim_head=dim_head, dropout=dropout,
+ last_stage=last_stage)),  # pre-norm wrapper around the attention block
+ PreNorm(dim=dim, fn=FeedForward(dim=dim, hidden_dim=mlp_dim, dropout=dropout))
+ ]) for _ in range(depth)
+ ])
+
+ def forward(self, x):
+ for attn, ff in self.layers:
+ x = x + attn(x)
+ x = x + ff(x)
+ return x
+
+
+class DBNet(nn.Module):  # main model class
+ def __init__(self, img_size, in_channels, num_classes, dim=64, kernels=[7, 3, 3, 3], strides=[4, 2, 2, 2],
+ heads=[1, 3, 6, 6],
+ depth=[1, 2, 10, 10], pool='cls', dropout=0., emb_dropout=0., scale_dim=4, ):
+ super().__init__()
+
+ assert pool in ['cls', 'mean'], 'pool type must be either cls or mean pooling'
+ self.pool = pool
+ self.dim = dim
+
+ # stage1
+ # k:7 s:4 in: 1, 64, 56, 56 out: 1, 3136, 64
+ self.stage1_conv_embed = nn.Sequential(
+ nn.Conv2d( # 1*3*224*224-->[1, 64, 56, 56]
+ in_channels=in_channels,
+ out_channels=dim,
+ kernel_size=kernels[0],
+ stride=strides[0],
+ padding=2
+ ),
+ Rearrange('b c h w -> b (h w) c', h=img_size // 4, w=img_size // 4), # [1, 64, 56, 56]-->[1, 3136, 64]
+ nn.LayerNorm(dim)  # layer normalization over the embedding dimension
+ )
+
+ self.stage1_transformer = nn.Sequential(
+ Transformer( #
+ dim=dim,
+ img_size=img_size // 4,
+ depth=depth[0],  # number of transformer blocks in this stage
+ heads=heads[0],
+ dim_head=self.dim,  # dimension of each attention head, typically embed_dim // num_heads
+ mlp_dim=dim * scale_dim,  # hidden dimension of the feed-forward network, typically embed_dim times a scale factor
+ dropout=dropout,
+ # last_stage=last_stage  # flag marking whether this transformer stage is the last one
+ ),
+ Rearrange('b (h w) c -> b c h w', h=img_size // 4, w=img_size // 4)
+ )
+
+ # stage2
+ # k:3 s:2 in: 1, 192, 28, 28 out: 1, 784, 192
+ in_channels = dim
+ scale = heads[1] // heads[0]
+ dim = scale * dim
+
+ self.stage2_conv_embed = nn.Sequential(
+ nn.Conv2d(
+ in_channels=in_channels,
+ out_channels=dim,
+ kernel_size=kernels[1],
+ stride=strides[1],
+ padding=1
+ ),
+ Rearrange('b c h w -> b (h w) c', h=img_size // 8, w=img_size // 8),
+ nn.LayerNorm(dim)
+ )
+
+ self.stage2_transformer = nn.Sequential(
+ Transformer(
+ dim=dim,
+ img_size=img_size // 8,
+ depth=depth[1],
+ heads=heads[1],
+ dim_head=self.dim,
+ mlp_dim=dim * scale_dim,
+ dropout=dropout
+ ),
+ Rearrange('b (h w) c -> b c h w', h=img_size // 8, w=img_size // 8)
+ )
+
+ # stage3
+ in_channels = dim
+ scale = heads[2] // heads[1]
+ dim = scale * dim
+
+ self.stage3_conv_embed = nn.Sequential(
+ nn.Conv2d(
+ in_channels=in_channels,
+ out_channels=dim,
+ kernel_size=kernels[2],
+ stride=strides[2],
+ padding=1
+ ),
+ Rearrange('b c h w -> b (h w) c', h=img_size // 16, w=img_size // 16),
+ nn.LayerNorm(dim)
+ )
+
+ self.stage3_transformer = nn.Sequential(
+ Transformer(
+ dim=dim,
+ img_size=img_size // 16,
+ depth=depth[2],
+ heads=heads[2],
+ dim_head=self.dim,
+ mlp_dim=dim * scale_dim,
+ dropout=dropout
+ ),
+ Rearrange('b (h w) c -> b c h w', h=img_size // 16, w=img_size // 16)
+ )
+
+ # stage4
+ in_channels = dim
+ scale = heads[3] // heads[2]
+ dim = scale * dim
+
+ self.stage4_conv_embed = nn.Sequential(
+ nn.Conv2d(
+ in_channels=in_channels,
+ out_channels=dim,
+ kernel_size=kernels[3],
+ stride=strides[3],
+ padding=1
+ ),
+ Rearrange('b c h w -> b (h w) c', h=img_size // 32, w=img_size // 32),
+ nn.LayerNorm(dim)
+ )
+
+ self.stage4_transformer = nn.Sequential(
+ Transformer(
+ dim=dim, img_size=img_size // 32,
+ depth=depth[3],
+ heads=heads[3],
+ dim_head=self.dim,
+ mlp_dim=dim * scale_dim,
+ dropout=dropout,
+ ),
+ Rearrange('b (h w) c -> b c h w', h=img_size // 32, w=img_size // 32)
+ )
+
+ ### CNN Branch ###
+ self.c_stage1 = c_stage123(in_chans=3, out_chans=64)
+ self.c_stage2 = c_stage123(in_chans=64, out_chans=128)
+ self.c_stage3 = c_stage123(in_chans=128, out_chans=384)
+ self.c_stage4 = c_stage45(in_chans=384, out_chans=512)
+ self.c_stage5 = c_stage45(in_chans=512, out_chans=1024)
+ self.c_max = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
+ self.up_conv1 = nn.Conv2d(in_channels=192, out_channels=128, kernel_size=1)
+ self.up_conv2 = nn.Conv2d(in_channels=384, out_channels=512, kernel_size=1)
+
+ ### CTmerge ###
+ self.CTmerge1 = nn.Sequential(
+ nn.Conv2d(in_channels=128, out_channels=64, kernel_size=3, stride=1, padding=1),
+ nn.BatchNorm2d(64),
+ nn.ReLU(),
+ nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, stride=1, padding=1),
+ nn.BatchNorm2d(64),
+ nn.ReLU(),
+ )
+ self.CTmerge2 = nn.Sequential(
+ nn.Conv2d(in_channels=320, out_channels=128, kernel_size=3, stride=1, padding=1),
+ nn.BatchNorm2d(128),
+ nn.ReLU(),
+ nn.Conv2d(in_channels=128, out_channels=128, kernel_size=3, stride=1, padding=1),
+ nn.BatchNorm2d(128),
+ nn.ReLU(),
+ )
+ self.CTmerge3 = nn.Sequential(
+ nn.Conv2d(in_channels=768, out_channels=512, kernel_size=3, stride=1, padding=1),
+ nn.BatchNorm2d(512),
+ nn.ReLU(),
+ nn.Conv2d(in_channels=512, out_channels=384, kernel_size=3, stride=1, padding=1),
+ nn.BatchNorm2d(384),
+ nn.ReLU(),
+ nn.Conv2d(in_channels=384, out_channels=384, kernel_size=3, stride=1, padding=1),
+ nn.BatchNorm2d(384),
+ nn.ReLU(),
+ )
+
+ self.CTmerge4 = nn.Sequential(
+ nn.Conv2d(in_channels=896, out_channels=640, kernel_size=3, stride=1, padding=1),
+ nn.BatchNorm2d(640),
+ nn.ReLU(),
+ nn.Conv2d(in_channels=640, out_channels=512, kernel_size=3, stride=1, padding=1),
+ nn.BatchNorm2d(512),
+ nn.ReLU(),
+ nn.Conv2d(in_channels=512, out_channels=512, kernel_size=3, stride=1, padding=1),
+ nn.BatchNorm2d(512),
+ nn.ReLU(),
+ )
+
+ # decoder
+ self.decoder4 = nn.Sequential(
+ DepthwiseConv2d(
+ in_chans=1408,
+ out_chans=1024,
+ kernel_size=3,
+ stride=1,
+ padding=1
+ ),
+ DepthwiseConv2d(
+ in_chans=1024,
+ out_chans=512,
+ kernel_size=3,
+ stride=1,
+ padding=1
+ ),
+ nn.GELU()
+ )
+ self.decoder3 = nn.Sequential(
+ DepthwiseConv2d(
+ in_chans=896,
+ out_chans=512,
+ kernel_size=3,
+ stride=1,
+ padding=1
+ ),
+ DepthwiseConv2d(
+ in_chans=512,
+ out_chans=384,
+ kernel_size=3,
+ stride=1,
+ padding=1
+ ),
+ nn.GELU()
+ )
+
+ self.decoder2 = nn.Sequential(
+ DepthwiseConv2d(
+ in_chans=576,
+ out_chans=256,
+ kernel_size=3,
+ stride=1,
+ padding=1
+ ),
+ DepthwiseConv2d(
+ in_chans=256,
+ out_chans=192,
+ kernel_size=3,
+ stride=1,
+ padding=1
+ ),
+ nn.GELU()
+ )
+
+ self.decoder1 = nn.Sequential(
+ DepthwiseConv2d(
+ in_chans=256,
+ out_chans=64,
+ kernel_size=3,
+ stride=1,
+ padding=1
+ ),
+ DepthwiseConv2d(
+ in_chans=64,
+ out_chans=16,
+ kernel_size=3,
+ stride=1,
+ padding=1
+ ),
+ nn.GELU()
+ )
+ self.sbr4 = SBR(512)
+ self.sbr3 = SBR(384)
+ self.sbr2 = SBR(192)
+ self.sbr1 = SBR(16)
+
+ self.head = nn.Conv2d(in_channels=16, out_channels=num_classes, kernel_size=1)
+
+ def forward(self, input):
+ ### encoder ###
+ # stage1 = ts1 cat cs1
+ # t_s1 = self.t_stage1(input)
+ # print(input.shape)
+ # print('++++++++++++++++++++++')
+
+ t_s1 = self.stage1_conv_embed(input) # 1*3*224*224-->1*3136*64
+
+ # print(t_s1.shape)
+ # print('======================')
+
+ t_s1 = self.stage1_transformer(t_s1) # 1*3136*64-->1*64*56*56
+
+ # print(t_s1.shape)
+ # print('----------------------')
+
+ c_s1 = self.c_stage1(input) # 1*3*224*224-->1*64*112*112
+
+ # print(c_s1.shape)
+ # print('!!!!!!!!!!!!!!!!!!!!!!!')
+
+ stage1 = self.CTmerge1(torch.cat([t_s1, self.c_max(c_s1)], dim=1))  # 1*64*56*56, concatenate the two branches
+
+ # print(stage1.shape)
+ # print('[[[[[[[[[[[[[[[[[[[[[[[')
+
+ # stage2 = ts2 up cs2
+ # t_s2 = self.t_stage2(stage1)
+ t_s2 = self.stage2_conv_embed(stage1)  # 1*64*56*56 --> 1*784*192; stage2_conv_embed flattens the feature map into a token sequence
+
+ # print(t_s2.shape)
+ # print('[[[[[[[[[[[[[[[[[[[[[[[')
+ t_s2 = self.stage2_transformer(t_s2) # 1*784*192-->1*192*28*28
+ # print(t_s2.shape)
+ # print('+++++++++++++++++++++++++')
+
+ c_s2 = self.c_stage2(c_s1) # 1*64*112*112-->1*128*56*56
+ stage2 = self.CTmerge2(
+ torch.cat([c_s2, F.interpolate(t_s2, size=c_s2.size()[2:], mode='bilinear', align_corners=True)],
+ dim=1))  # bilinear interpolation; 1*128*56*56
+
+ # stage3 = ts3 cat cs3
+ # t_s3 = self.t_stage3(t_s2)
+ t_s3 = self.stage3_conv_embed(t_s2) # 1*192*28*28-->1*196*384
+ # print(t_s3.shape)
+ # print('///////////////////////')
+ t_s3 = self.stage3_transformer(t_s3) # 1*196*384-->1*384*14*14
+ # print(t_s3.shape)
+ # print('....................')
+ c_s3 = self.c_stage3(stage2) # 1*128*56*56-->1*384*28*28
+ stage3 = self.CTmerge3(torch.cat([t_s3, self.c_max(c_s3)], dim=1)) # 1*384*14*14
+
+ # stage4 = ts4 up cs4
+ # t_s4 = self.t_stage4(stage3)
+ t_s4 = self.stage4_conv_embed(stage3) # 1*384*14*14-->1*49*384
+ # print(t_s4.shape)
+ # print(';;;;;;;;;;;;;;;;;;;;;;;')
+ t_s4 = self.stage4_transformer(t_s4) # 1*49*384-->1*384*7*7
+ # print(t_s4.shape)
+ # print('::::::::::::::::::::')
+
+ c_s4 = self.c_stage4(c_s3) # 1*384*28*28-->1*512*14*14
+ stage4 = self.CTmerge4(
+ torch.cat([c_s4, F.interpolate(t_s4, size=c_s4.size()[2:], mode='bilinear', align_corners=True)],
+ dim=1)) # 1*512*14*14
+
+ # cs5
+ c_s5 = self.c_stage5(stage4) # 1*512*14*14-->1*1024*7*7
+
+ ### decoder ###
+ decoder4 = torch.cat([c_s5, t_s4], dim=1) # 1*1408*7*7
+ decoder4 = self.decoder4(decoder4) # 1*1408*7*7-->1*512*7*7
+ decoder4 = F.interpolate(decoder4, size=c_s3.size()[2:], mode='bilinear',
+ align_corners=True) # 1*512*7*7-->1*512*28*28
+ decoder4 = self.sbr4(decoder4) # 1*512*28*28
+ # print(decoder4.shape)
+
+ decoder3 = torch.cat([decoder4, c_s3], dim=1) # 1*896*28*28
+ decoder3 = self.decoder3(decoder3) # 1*384*28*28
+ decoder3 = F.interpolate(decoder3, size=t_s2.size()[2:], mode='bilinear', align_corners=True) # 1*384*28*28
+ decoder3 = self.sbr3(decoder3) # 1*384*28*28
+ # print(decoder3.shape)
+
+ decoder2 = torch.cat([decoder3, t_s2], dim=1) # 1*576*28*28
+ decoder2 = self.decoder2(decoder2) # 1*192*28*28
+ decoder2 = F.interpolate(decoder2, size=c_s1.size()[2:], mode='bilinear', align_corners=True) # 1*192*112*112
+ decoder2 = self.sbr2(decoder2) # 1*192*112*112
+ # print(decoder2.shape)
+
+ decoder1 = torch.cat([decoder2, c_s1], dim=1) # 1*256*112*112
+ decoder1 = self.decoder1(decoder1) # 1*16*112*112
+ # print(decoder1.shape)
+ final = F.interpolate(decoder1, size=input.size()[2:], mode='bilinear', align_corners=True) # 1*16*224*224
+ # print(final.shape)
+ # final = self.sbr1(decoder1)
+ # print(final.shape)
+ final = self.head(final)  # 1*num_classes*224*224
+
+ return final
+
+
+if __name__ == '__main__':
+ x = torch.rand(1, 3, 224, 224).cuda()
+ model = DBNet(img_size=224, in_channels=3, num_classes=7).cuda()
+ y = model(x)
+ print(y.shape)
+ # torch.Size([1, 7, 224, 224])
diff --git a/src/models/components/hrcloudnet.py b/src/models/components/hrcloudnet.py
new file mode 100644
index 0000000000000000000000000000000000000000..61682e4b41bf606348e5e12fc98c2bfdae020091
--- /dev/null
+++ b/src/models/components/hrcloudnet.py
@@ -0,0 +1,751 @@
+# Paper: https://arxiv.org/abs/2407.07365
+#
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import logging
+import os
+
+import numpy as np
+import torch
+import torch._utils
+import torch.nn as nn
+import torch.nn.functional as F
+
+BatchNorm2d = nn.BatchNorm2d
+# BN_MOMENTUM = 0.01
+relu_inplace = True
+BN_MOMENTUM = 0.1
+ALIGN_CORNERS = True
+
+logger = logging.getLogger(__name__)
+
+
+def conv3x3(in_planes, out_planes, stride=1):
+ """3x3 convolution with padding"""
+ return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
+ padding=1, bias=False)
+
+
+from yacs.config import CfgNode as CN
+import math
+from einops import rearrange
+
+# configs for HRNet48
+HRNET_48 = CN()
+HRNET_48.FINAL_CONV_KERNEL = 1
+
+HRNET_48.STAGE1 = CN()
+HRNET_48.STAGE1.NUM_MODULES = 1
+HRNET_48.STAGE1.NUM_BRANCHES = 1
+HRNET_48.STAGE1.NUM_BLOCKS = [4]
+HRNET_48.STAGE1.NUM_CHANNELS = [64]
+HRNET_48.STAGE1.BLOCK = 'BOTTLENECK'
+HRNET_48.STAGE1.FUSE_METHOD = 'SUM'
+
+HRNET_48.STAGE2 = CN()
+HRNET_48.STAGE2.NUM_MODULES = 1
+HRNET_48.STAGE2.NUM_BRANCHES = 2
+HRNET_48.STAGE2.NUM_BLOCKS = [4, 4]
+HRNET_48.STAGE2.NUM_CHANNELS = [48, 96]
+HRNET_48.STAGE2.BLOCK = 'BASIC'
+HRNET_48.STAGE2.FUSE_METHOD = 'SUM'
+
+HRNET_48.STAGE3 = CN()
+HRNET_48.STAGE3.NUM_MODULES = 4
+HRNET_48.STAGE3.NUM_BRANCHES = 3
+HRNET_48.STAGE3.NUM_BLOCKS = [4, 4, 4]
+HRNET_48.STAGE3.NUM_CHANNELS = [48, 96, 192]
+HRNET_48.STAGE3.BLOCK = 'BASIC'
+HRNET_48.STAGE3.FUSE_METHOD = 'SUM'
+
+HRNET_48.STAGE4 = CN()
+HRNET_48.STAGE4.NUM_MODULES = 3
+HRNET_48.STAGE4.NUM_BRANCHES = 4
+HRNET_48.STAGE4.NUM_BLOCKS = [4, 4, 4, 4]
+HRNET_48.STAGE4.NUM_CHANNELS = [48, 96, 192, 384]
+HRNET_48.STAGE4.BLOCK = 'BASIC'
+HRNET_48.STAGE4.FUSE_METHOD = 'SUM'
+
+HRNET_32 = CN()
+HRNET_32.FINAL_CONV_KERNEL = 1
+
+HRNET_32.STAGE1 = CN()
+HRNET_32.STAGE1.NUM_MODULES = 1
+HRNET_32.STAGE1.NUM_BRANCHES = 1
+HRNET_32.STAGE1.NUM_BLOCKS = [4]
+HRNET_32.STAGE1.NUM_CHANNELS = [64]
+HRNET_32.STAGE1.BLOCK = 'BOTTLENECK'
+HRNET_32.STAGE1.FUSE_METHOD = 'SUM'
+
+HRNET_32.STAGE2 = CN()
+HRNET_32.STAGE2.NUM_MODULES = 1
+HRNET_32.STAGE2.NUM_BRANCHES = 2
+HRNET_32.STAGE2.NUM_BLOCKS = [4, 4]
+HRNET_32.STAGE2.NUM_CHANNELS = [32, 64]
+HRNET_32.STAGE2.BLOCK = 'BASIC'
+HRNET_32.STAGE2.FUSE_METHOD = 'SUM'
+
+HRNET_32.STAGE3 = CN()
+HRNET_32.STAGE3.NUM_MODULES = 4
+HRNET_32.STAGE3.NUM_BRANCHES = 3
+HRNET_32.STAGE3.NUM_BLOCKS = [4, 4, 4]
+HRNET_32.STAGE3.NUM_CHANNELS = [32, 64, 128]
+HRNET_32.STAGE3.BLOCK = 'BASIC'
+HRNET_32.STAGE3.FUSE_METHOD = 'SUM'
+
+HRNET_32.STAGE4 = CN()
+HRNET_32.STAGE4.NUM_MODULES = 3
+HRNET_32.STAGE4.NUM_BRANCHES = 4
+HRNET_32.STAGE4.NUM_BLOCKS = [4, 4, 4, 4]
+HRNET_32.STAGE4.NUM_CHANNELS = [32, 64, 128, 256]
+HRNET_32.STAGE4.BLOCK = 'BASIC'
+HRNET_32.STAGE4.FUSE_METHOD = 'SUM'
+
+HRNET_18 = CN()
+HRNET_18.FINAL_CONV_KERNEL = 1
+
+HRNET_18.STAGE1 = CN()
+HRNET_18.STAGE1.NUM_MODULES = 1
+HRNET_18.STAGE1.NUM_BRANCHES = 1
+HRNET_18.STAGE1.NUM_BLOCKS = [4]
+HRNET_18.STAGE1.NUM_CHANNELS = [64]
+HRNET_18.STAGE1.BLOCK = 'BOTTLENECK'
+HRNET_18.STAGE1.FUSE_METHOD = 'SUM'
+
+HRNET_18.STAGE2 = CN()
+HRNET_18.STAGE2.NUM_MODULES = 1
+HRNET_18.STAGE2.NUM_BRANCHES = 2
+HRNET_18.STAGE2.NUM_BLOCKS = [4, 4]
+HRNET_18.STAGE2.NUM_CHANNELS = [18, 36]
+HRNET_18.STAGE2.BLOCK = 'BASIC'
+HRNET_18.STAGE2.FUSE_METHOD = 'SUM'
+
+HRNET_18.STAGE3 = CN()
+HRNET_18.STAGE3.NUM_MODULES = 4
+HRNET_18.STAGE3.NUM_BRANCHES = 3
+HRNET_18.STAGE3.NUM_BLOCKS = [4, 4, 4]
+HRNET_18.STAGE3.NUM_CHANNELS = [18, 36, 72]
+HRNET_18.STAGE3.BLOCK = 'BASIC'
+HRNET_18.STAGE3.FUSE_METHOD = 'SUM'
+
+HRNET_18.STAGE4 = CN()
+HRNET_18.STAGE4.NUM_MODULES = 3
+HRNET_18.STAGE4.NUM_BRANCHES = 4
+HRNET_18.STAGE4.NUM_BLOCKS = [4, 4, 4, 4]
+HRNET_18.STAGE4.NUM_CHANNELS = [18, 36, 72, 144]
+HRNET_18.STAGE4.BLOCK = 'BASIC'
+HRNET_18.STAGE4.FUSE_METHOD = 'SUM'
+
+
+class PPM(nn.Module):
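+ # Pyramid Pooling Module: adaptively pools the input to each bin size, projects each pooled map
+ # to reduction_dim with a 1x1 conv, upsamples back and concatenates everything with the input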
+ def __init__(self, in_dim, reduction_dim, bins):
+ super(PPM, self).__init__()
+ self.features = []
+ for bin in bins:
+ self.features.append(nn.Sequential(
+ nn.AdaptiveAvgPool2d(bin),
+ nn.Conv2d(in_dim, reduction_dim, kernel_size=1, bias=False),
+ nn.BatchNorm2d(reduction_dim),
+ nn.ReLU(inplace=True)
+ ))
+ self.features = nn.ModuleList(self.features)
+
+ def forward(self, x):
+ x_size = x.size()
+ out = [x]
+ for f in self.features:
+ out.append(F.interpolate(f(x), x_size[2:], mode='bilinear', align_corners=True))
+ return torch.cat(out, 1)
+
+
+class BasicBlock(nn.Module):
+ expansion = 1
+
+ def __init__(self, inplanes, planes, stride=1, downsample=None):
+ super(BasicBlock, self).__init__()
+ self.conv1 = conv3x3(inplanes, planes, stride)
+ self.bn1 = BatchNorm2d(planes, momentum=BN_MOMENTUM)
+ self.relu = nn.ReLU(inplace=relu_inplace)
+ self.conv2 = conv3x3(planes, planes)
+ self.bn2 = BatchNorm2d(planes, momentum=BN_MOMENTUM)
+ self.downsample = downsample
+ self.stride = stride
+
+ def forward(self, x):
+ residual = x
+
+ out = self.conv1(x)
+ out = self.bn1(out)
+ out = self.relu(out)
+
+ out = self.conv2(out)
+ out = self.bn2(out)
+
+ if self.downsample is not None:
+ residual = self.downsample(x)
+ out = out + residual
+ out = self.relu(out)
+
+ return out
+
+
+class Bottleneck(nn.Module):
+ expansion = 4
+
+ def __init__(self, inplanes, planes, stride=1, downsample=None):
+ super(Bottleneck, self).__init__()
+ self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
+ self.bn1 = BatchNorm2d(planes, momentum=BN_MOMENTUM)
+ self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride,
+ padding=1, bias=False)
+ self.bn2 = BatchNorm2d(planes, momentum=BN_MOMENTUM)
+ self.conv3 = nn.Conv2d(planes, planes * self.expansion, kernel_size=1,
+ bias=False)
+ self.bn3 = BatchNorm2d(planes * self.expansion,
+ momentum=BN_MOMENTUM)
+ self.relu = nn.ReLU(inplace=relu_inplace)
+ self.downsample = downsample
+ self.stride = stride
+
+ def forward(self, x):
+ residual = x
+
+ out = self.conv1(x)
+ out = self.bn1(out)
+ out = self.relu(out)
+
+ out = self.conv2(out)
+ out = self.bn2(out)
+ out = self.relu(out)
+
+ out = self.conv3(out)
+ out = self.bn3(out)
+
+ if self.downsample is not None:
+ residual = self.downsample(x)
+ # att = self.downsample(att)
+ out = out + residual
+ out = self.relu(out)
+
+ return out
+
+
+class HighResolutionModule(nn.Module):
+ def __init__(self, num_branches, blocks, num_blocks, num_inchannels,
+ num_channels, fuse_method, multi_scale_output=True):
+ super(HighResolutionModule, self).__init__()
+ self._check_branches(
+ num_branches, blocks, num_blocks, num_inchannels, num_channels)
+
+ self.num_inchannels = num_inchannels
+ self.fuse_method = fuse_method
+ self.num_branches = num_branches
+
+ self.multi_scale_output = multi_scale_output
+
+ self.branches = self._make_branches(
+ num_branches, blocks, num_blocks, num_channels)
+ self.fuse_layers = self._make_fuse_layers()
+ self.relu = nn.ReLU(inplace=relu_inplace)
+
+ def _check_branches(self, num_branches, blocks, num_blocks,
+ num_inchannels, num_channels):
+ if num_branches != len(num_blocks):
+ error_msg = 'NUM_BRANCHES({}) <> NUM_BLOCKS({})'.format(
+ num_branches, len(num_blocks))
+ logger.error(error_msg)
+ raise ValueError(error_msg)
+
+ if num_branches != len(num_channels):
+ error_msg = 'NUM_BRANCHES({}) <> NUM_CHANNELS({})'.format(
+ num_branches, len(num_channels))
+ logger.error(error_msg)
+ raise ValueError(error_msg)
+
+ if num_branches != len(num_inchannels):
+ error_msg = 'NUM_BRANCHES({}) <> NUM_INCHANNELS({})'.format(
+ num_branches, len(num_inchannels))
+ logger.error(error_msg)
+ raise ValueError(error_msg)
+
+ def _make_one_branch(self, branch_index, block, num_blocks, num_channels,
+ stride=1):
+ downsample = None
+ if stride != 1 or \
+ self.num_inchannels[branch_index] != num_channels[branch_index] * block.expansion:
+ downsample = nn.Sequential(
+ nn.Conv2d(self.num_inchannels[branch_index],
+ num_channels[branch_index] * block.expansion,
+ kernel_size=1, stride=stride, bias=False),
+ BatchNorm2d(num_channels[branch_index] * block.expansion,
+ momentum=BN_MOMENTUM),
+ )
+
+ layers = []
+ layers.append(block(self.num_inchannels[branch_index],
+ num_channels[branch_index], stride, downsample))
+ self.num_inchannels[branch_index] = \
+ num_channels[branch_index] * block.expansion
+ for i in range(1, num_blocks[branch_index]):
+ layers.append(block(self.num_inchannels[branch_index],
+ num_channels[branch_index]))
+
+ return nn.Sequential(*layers)
+
+ # build the parallel branches
+ def _make_branches(self, num_branches, block, num_blocks, num_channels):
+ branches = []
+
+ for i in range(num_branches):
+ branches.append(
+ self._make_one_branch(i, block, num_blocks, num_channels))
+
+ return nn.ModuleList(branches)
+
+ def _make_fuse_layers(self):
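+ # for each output branch i, bring every branch j to i's resolution and width:
+ # a 1x1 conv (upsampled later in forward) when j > i, identity when j == i,
+ # and a chain of stride-2 3x3 convs when j < i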
+ if self.num_branches == 1:
+ return None
+ num_branches = self.num_branches # 3
+ num_inchannels = self.num_inchannels # [48, 96, 192]
+ fuse_layers = []
+ for i in range(num_branches if self.multi_scale_output else 1):
+ fuse_layer = []
+ for j in range(num_branches):
+ if j > i:
+ fuse_layer.append(nn.Sequential(
+ nn.Conv2d(num_inchannels[j],
+ num_inchannels[i],
+ 1,
+ 1,
+ 0,
+ bias=False),
+ BatchNorm2d(num_inchannels[i], momentum=BN_MOMENTUM)))
+ elif j == i:
+ fuse_layer.append(None)
+ else:
+ conv3x3s = []
+ for k in range(i - j):
+ if k == i - j - 1:
+ num_outchannels_conv3x3 = num_inchannels[i]
+ conv3x3s.append(nn.Sequential(
+ nn.Conv2d(num_inchannels[j],
+ num_outchannels_conv3x3,
+ 3, 2, 1, bias=False),
+ BatchNorm2d(num_outchannels_conv3x3,
+ momentum=BN_MOMENTUM)))
+ else:
+ num_outchannels_conv3x3 = num_inchannels[j]
+ conv3x3s.append(nn.Sequential(
+ nn.Conv2d(num_inchannels[j],
+ num_outchannels_conv3x3,
+ 3, 2, 1, bias=False),
+ BatchNorm2d(num_outchannels_conv3x3,
+ momentum=BN_MOMENTUM),
+ nn.ReLU(inplace=relu_inplace)))
+ fuse_layer.append(nn.Sequential(*conv3x3s))
+ fuse_layers.append(nn.ModuleList(fuse_layer))
+
+ return nn.ModuleList(fuse_layers)
+
+ def get_num_inchannels(self):
+ return self.num_inchannels
+
+ def forward(self, x):
+ if self.num_branches == 1:
+ return [self.branches[0](x[0])]
+
+ for i in range(self.num_branches):
+ x[i] = self.branches[i](x[i])
+
+ x_fuse = []
+ for i in range(len(self.fuse_layers)):
+ y = x[0] if i == 0 else self.fuse_layers[i][0](x[0])
+ for j in range(1, self.num_branches):
+ if i == j:
+ y = y + x[j]
+ elif j > i:
+ width_output = x[i].shape[-1]
+ height_output = x[i].shape[-2]
+ y = y + F.interpolate(
+ self.fuse_layers[i][j](x[j]),
+ size=[height_output, width_output],
+ mode='bilinear', align_corners=ALIGN_CORNERS)
+ else:
+ y = y + self.fuse_layers[i][j](x[j])
+ x_fuse.append(self.relu(y))
+
+ return x_fuse
+
+
+blocks_dict = {
+ 'BASIC': BasicBlock,
+ 'BOTTLENECK': Bottleneck
+}
+
+
+class HRCloudNet(nn.Module):
+
+ def __init__(self, num_classes=2, base_c=48, **kwargs):
+ global ALIGN_CORNERS
+ extra = HRNET_48
+ super(HRCloudNet, self).__init__()
+ ALIGN_CORNERS = True
+ # ALIGN_CORNERS = config.MODEL.ALIGN_CORNERS
+ self.num_classes = num_classes
+ # stem net
+ self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=2, padding=1,
+ bias=False)
+ self.bn1 = BatchNorm2d(64, momentum=BN_MOMENTUM)
+ self.conv2 = nn.Conv2d(64, 64, kernel_size=3, stride=2, padding=1,
+ bias=False)
+ self.bn2 = BatchNorm2d(64, momentum=BN_MOMENTUM)
+ self.relu = nn.ReLU(inplace=relu_inplace)
+
+ self.stage1_cfg = extra['STAGE1']
+ num_channels = self.stage1_cfg['NUM_CHANNELS'][0]
+ block = blocks_dict[self.stage1_cfg['BLOCK']]
+ num_blocks = self.stage1_cfg['NUM_BLOCKS'][0]
+ self.layer1 = self._make_layer(block, 64, num_channels, num_blocks)
+ stage1_out_channel = block.expansion * num_channels
+
+ self.stage2_cfg = extra['STAGE2']
+ num_channels = self.stage2_cfg['NUM_CHANNELS']
+ block = blocks_dict[self.stage2_cfg['BLOCK']]
+ num_channels = [
+ num_channels[i] * block.expansion for i in range(len(num_channels))]
+ self.transition1 = self._make_transition_layer(
+ [stage1_out_channel], num_channels)
+ self.stage2, pre_stage_channels = self._make_stage(
+ self.stage2_cfg, num_channels)
+
+ self.stage3_cfg = extra['STAGE3']
+ num_channels = self.stage3_cfg['NUM_CHANNELS']
+ block = blocks_dict[self.stage3_cfg['BLOCK']]
+ num_channels = [
+ num_channels[i] * block.expansion for i in range(len(num_channels))]
+ self.transition2 = self._make_transition_layer(
+ pre_stage_channels, num_channels)  # only downsamples between pre[-1] and cur[-1]?
+ self.stage3, pre_stage_channels = self._make_stage(
+ self.stage3_cfg, num_channels)
+
+ self.stage4_cfg = extra['STAGE4']
+ num_channels = self.stage4_cfg['NUM_CHANNELS']
+ block = blocks_dict[self.stage4_cfg['BLOCK']]
+ num_channels = [
+ num_channels[i] * block.expansion for i in range(len(num_channels))]
+ self.transition3 = self._make_transition_layer(
+ pre_stage_channels, num_channels)
+ self.stage4, pre_stage_channels = self._make_stage(
+ self.stage4_cfg, num_channels, multi_scale_output=True)
+ self.out_conv = OutConv(base_c, num_classes)
+ last_inp_channels = int(np.sum(pre_stage_channels))
+
+ self.corr = Corr(nclass=2)
+ self.proj = nn.Sequential(
+ # 512 32
+ nn.Conv2d(720, 48, kernel_size=3, stride=1, padding=1, bias=True),
+ nn.BatchNorm2d(48),
+ nn.ReLU(inplace=True),
+ nn.Dropout2d(0.1),
+ )
+ # self.up1 = Up(base_c * 16, base_c * 8 // factor, bilinear)
+ self.up2 = Up(base_c * 8, base_c * 4, True)
+ self.up3 = Up(base_c * 4, base_c * 2, True)
+ self.up4 = Up(base_c * 2, base_c, True)
+ fea_dim = 720
+ bins = (1, 2, 3, 6)
+ self.ppm = PPM(fea_dim, int(fea_dim / len(bins)), bins)
+ fea_dim *= 2
+ self.cls = nn.Sequential(
+ nn.Conv2d(fea_dim, 512, kernel_size=3, padding=1, bias=False),
+ nn.BatchNorm2d(512),
+ nn.ReLU(inplace=True),
+ nn.Dropout2d(p=0.1),
+ nn.Conv2d(512, 2, kernel_size=1)
+ )
+
+ '''
+ The transition layers handle two cases:
+
+ When the current number of branches is smaller than the previous one, only the channel counts of the existing branches are adjusted.
+ When the current number of branches is larger, new layers are created that downsample the extra branches and change their channel counts to fit the following stage.
+ The resulting layers are collected into an nn.ModuleList and used while building the network.
+ This ensures the channel counts of each branch match between stages so that features can be fused and connected.
+ '''
+
+ def _make_transition_layer(
+ self, num_channels_pre_layer, num_channels_cur_layer):
+ # number of branches in the current stage
+ num_branches_cur = len(num_channels_cur_layer) # 3
+ # number of branches in the previous stage
+ num_branches_pre = len(num_channels_pre_layer) # 2
+
+ transition_layers = []
+ for i in range(num_branches_cur):
+ # if this branch already exists in the previous stage (only relevant for the first-to-second stage transition)
+ if i < num_branches_pre:
+ # if the channel counts of the corresponding layers differ, add a conversion layer
+ if num_channels_cur_layer[i] != num_channels_pre_layer[i]:
+ transition_layers.append(nn.Sequential(
+
+ nn.Conv2d(num_channels_pre_layer[i],
+ num_channels_cur_layer[i],
+ 3,
+ 1,
+ 1,
+ bias=False),
+ BatchNorm2d(
+ num_channels_cur_layer[i], momentum=BN_MOMENTUM),
+ nn.ReLU(inplace=relu_inplace)))
+ else:
+ transition_layers.append(None)
+ else:  # new branches: downsample and change the channel count
+ conv3x3s = []
+ for j in range(i + 1 - num_branches_pre): # 3
+ inchannels = num_channels_pre_layer[-1]
+ outchannels = num_channels_cur_layer[i] \
+ if j == i - num_branches_pre else inchannels
+ conv3x3s.append(nn.Sequential(
+ nn.Conv2d(
+ inchannels, outchannels, 3, 2, 1, bias=False),
+ BatchNorm2d(outchannels, momentum=BN_MOMENTUM),
+ nn.ReLU(inplace=relu_inplace)))
+ transition_layers.append(nn.Sequential(*conv3x3s))
+
+ return nn.ModuleList(transition_layers)
+
+ '''
+ _make_layer builds a layer composed of several residual blocks of the same type.
+ '''
+
+ def _make_layer(self, block, inplanes, planes, blocks, stride=1):
+ downsample = None
+ if stride != 1 or inplanes != planes * block.expansion:
+ downsample = nn.Sequential(
+ nn.Conv2d(inplanes, planes * block.expansion,
+ kernel_size=1, stride=stride, bias=False),
+ BatchNorm2d(planes * block.expansion, momentum=BN_MOMENTUM),
+ )
+
+ layers = []
+ layers.append(block(inplanes, planes, stride, downsample))
+ inplanes = planes * block.expansion
+ for i in range(1, blocks):
+ layers.append(block(inplanes, planes))
+
+ return nn.Sequential(*layers)
+
+ # multi-scale fusion
+ def _make_stage(self, layer_config, num_inchannels,
+ multi_scale_output=True):
+ num_modules = layer_config['NUM_MODULES']
+ num_branches = layer_config['NUM_BRANCHES']
+ num_blocks = layer_config['NUM_BLOCKS']
+ num_channels = layer_config['NUM_CHANNELS']
+ block = blocks_dict[layer_config['BLOCK']]
+ fuse_method = layer_config['FUSE_METHOD']
+
+ modules = []
+ for i in range(num_modules):  # repeated num_modules times
+ # multi_scale_output is only used last module
+ if not multi_scale_output and i == num_modules - 1:
+ reset_multi_scale_output = False
+ else:
+ reset_multi_scale_output = True
+ modules.append(
+ HighResolutionModule(num_branches,
+ block,
+ num_blocks,
+ num_inchannels,
+ num_channels,
+ fuse_method,
+ reset_multi_scale_output)
+ )
+ num_inchannels = modules[-1].get_num_inchannels()
+
+ return nn.Sequential(*modules), num_inchannels
+
+ def forward(self, input, need_fp=True, use_corr=True):
+ # from ipdb import set_trace
+ # set_trace()
+ x = self.conv1(input)
+ x = self.bn1(x)
+ x = self.relu(x)
+ # x_176 = x
+ x = self.conv2(x)
+ x = self.bn2(x)
+ x = self.relu(x)
+ x = self.layer1(x)
+
+ x_list = []
+ for i in range(self.stage2_cfg['NUM_BRANCHES']): # 2
+ if self.transition1[i] is not None:
+ x_list.append(self.transition1[i](x))
+ else:
+ x_list.append(x)
+ y_list = self.stage2(x_list)
+ # Y1
+ x_list = []
+ for i in range(self.stage3_cfg['NUM_BRANCHES']):
+ if self.transition2[i] is not None:
+ if i < self.stage2_cfg['NUM_BRANCHES']:
+ x_list.append(self.transition2[i](y_list[i]))
+ else:
+ x_list.append(self.transition2[i](y_list[-1]))
+ else:
+ x_list.append(y_list[i])
+ y_list = self.stage3(x_list)
+
+ x_list = []
+ for i in range(self.stage4_cfg['NUM_BRANCHES']):
+ if self.transition3[i] is not None:
+ if i < self.stage3_cfg['NUM_BRANCHES']:
+ x_list.append(self.transition3[i](y_list[i]))
+ else:
+ x_list.append(self.transition3[i](y_list[-1]))
+ else:
+ x_list.append(y_list[i])
+ x = self.stage4(x_list)
+ dict_return = {}
+ # Upsampling
+ x0_h, x0_w = x[0].size(2), x[0].size(3)
+
+ x3 = F.interpolate(x[3], size=(x0_h, x0_w), mode='bilinear', align_corners=ALIGN_CORNERS)
+ # x = self.stage3_(x)
+ x[2] = self.up2(x[3], x[2])
+ x2 = F.interpolate(x[2], size=(x0_h, x0_w), mode='bilinear', align_corners=ALIGN_CORNERS)
+ # x = self.stage2_(x)
+ x[1] = self.up3(x[2], x[1])
+ x1 = F.interpolate(x[1], size=(x0_h, x0_w), mode='bilinear', align_corners=ALIGN_CORNERS)
+ x[0] = self.up4(x[1], x[0])
+ xk = torch.cat([x[0], x1, x2, x3], 1)
+ # PPM
+ feat = self.ppm(xk)
+ x = self.cls(feat)
+ # fp branch
+ if need_fp:
+ logits = F.interpolate(x, size=input.size()[2:], mode='bilinear', align_corners=True)
+ # logits = self.out_conv(torch.cat((x, nn.Dropout2d(0.5)(x))))
+ out = logits
+ out_fp = logits
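+ # NOTE: with the dropout-based perturbation above commented out, the "fp" output
+ # is currently identical to the clean output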
+ if use_corr:
+ proj_feats = self.proj(xk)
+ corr_out = self.corr(proj_feats, out)
+ corr_out = F.interpolate(corr_out, size=(352, 352), mode="bilinear", align_corners=True)
+ dict_return['corr_out'] = corr_out
+ dict_return['out'] = out
+ dict_return['out_fp'] = out_fp
+
+ return dict_return['out']
+
+ out = F.interpolate(x, size=input.size()[2:], mode='bilinear', align_corners=True)
+ if use_corr: # True
+ proj_feats = self.proj(xk)
+ # compute the correlation
+ corr_out = self.corr(proj_feats, out)
+ corr_out = F.interpolate(corr_out, size=(352, 352), mode="bilinear", align_corners=True)
+ dict_return['corr_out'] = corr_out
+ dict_return['out'] = out
+ return dict_return['out']
+ # return x
+
+ def init_weights(self, pretrained='', ):
+ logger.info('=> init weights from normal distribution')
+ for m in self.modules():
+ if isinstance(m, nn.Conv2d):
+ nn.init.normal_(m.weight, std=0.001)
+ elif isinstance(m, nn.BatchNorm2d):
+ nn.init.constant_(m.weight, 1)
+ nn.init.constant_(m.bias, 0)
+ if os.path.isfile(pretrained):
+ pretrained_dict = torch.load(pretrained)
+ logger.info('=> loading pretrained model {}'.format(pretrained))
+ model_dict = self.state_dict()
+ pretrained_dict = {k: v for k, v in pretrained_dict.items()
+ if k in model_dict.keys()}
+ for k, _ in pretrained_dict.items():
+ logger.info(
+ '=> loading {} pretrained model {}'.format(k, pretrained))
+ model_dict.update(pretrained_dict)
+ self.load_state_dict(model_dict)
+
+
+class OutConv(nn.Sequential):
+ def __init__(self, in_channels, num_classes):
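+ # NOTE: in_channels is currently unused; the conv is hard-wired to the 720 channels
+ # of the concatenated HRNet feature maps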
+ super(OutConv, self).__init__(
+ nn.Conv2d(720, num_classes, kernel_size=1)
+ )
+
+
+class DoubleConv(nn.Sequential):
+ def __init__(self, in_channels, out_channels, mid_channels=None):
+ if mid_channels is None:
+ mid_channels = out_channels
+ super(DoubleConv, self).__init__(
+ nn.Conv2d(in_channels + out_channels, mid_channels, kernel_size=3, padding=1, bias=False),
+ nn.BatchNorm2d(mid_channels),
+ nn.ReLU(inplace=True),
+ nn.Conv2d(mid_channels, out_channels, kernel_size=3, padding=1, bias=False),
+ nn.BatchNorm2d(out_channels),
+ nn.ReLU(inplace=True)
+ )
+
+
+class Up(nn.Module):
+ def __init__(self, in_channels, out_channels, bilinear=True):
+ super(Up, self).__init__()
+ if bilinear:
+ self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
+ self.conv = DoubleConv(in_channels, out_channels, in_channels // 2)
+ else:
+ self.up = nn.ConvTranspose2d(in_channels, in_channels // 2, kernel_size=2, stride=2)
+ self.conv = DoubleConv(in_channels, out_channels)
+
+ def forward(self, x1: torch.Tensor, x2: torch.Tensor) -> torch.Tensor:
+ x1 = self.up(x1)
+ # [N, C, H, W]
+ diff_y = x2.size()[2] - x1.size()[2]
+ diff_x = x2.size()[3] - x1.size()[3]
+
+ # padding_left, padding_right, padding_top, padding_bottom
+ x1 = F.pad(x1, [diff_x // 2, diff_x - diff_x // 2,
+ diff_y // 2, diff_y - diff_y // 2])
+
+ x = torch.cat([x2, x1], dim=1)
+ x = self.conv(x)
+ return x
+
+
+class Corr(nn.Module):
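+ # correlation module: builds an (h*w) x (h*w) affinity map from two 1x1 projections of the
+ # features and uses it to propagate the downsampled class predictions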
+ def __init__(self, nclass=2):
+ super(Corr, self).__init__()
+ self.nclass = nclass
+ self.conv1 = nn.Conv2d(48, self.nclass, kernel_size=1, stride=1, padding=0, bias=True)
+ self.conv2 = nn.Conv2d(48, self.nclass, kernel_size=1, stride=1, padding=0, bias=True)
+
+ def forward(self, feature_in, out):
+ # in torch.Size([4, 32, 22, 22])
+ # out = [4 2 352 352]
+ h_in, w_in = math.ceil(feature_in.shape[2] / (1)), math.ceil(feature_in.shape[3] / (1))
+ out = F.interpolate(out.detach(), (h_in, w_in), mode='bilinear', align_corners=True)
+ feature = F.interpolate(feature_in, (h_in, w_in), mode='bilinear', align_corners=True)
+ f1 = rearrange(self.conv1(feature), 'n c h w -> n c (h w)')
+ f2 = rearrange(self.conv2(feature), 'n c h w -> n c (h w)')
+ out_temp = rearrange(out, 'n c h w -> n c (h w)')
+ corr_map = torch.matmul(f1.transpose(1, 2), f2) / torch.sqrt(torch.tensor(f1.shape[1]).float())
+ corr_map = F.softmax(corr_map, dim=-1)
+ # out_temp 2 2 484
+ # corr_map 4 484 484
+ out = rearrange(torch.matmul(out_temp, corr_map), 'n c (h w) -> n c h w', h=h_in, w=w_in)
+ # out torch.Size([4, 2, 22, 22])
+ return out
+
+
+if __name__ == '__main__':
+ input = torch.randn(4, 3, 352, 352)
+ cloud = HRCloudNet(num_classes=2)
+ output = cloud(input)
+ print(output.shape)
+ # torch.Size([4, 2, 352, 352]) torch.Size([4, 2, 352, 352]) torch.Size([4, 2, 352, 352])
\ No newline at end of file
diff --git a/src/models/components/lnn.py b/src/models/components/lnn.py
new file mode 100644
index 0000000000000000000000000000000000000000..751a68818abb03a4b27f49b154ac4d4229da6570
--- /dev/null
+++ b/src/models/components/lnn.py
@@ -0,0 +1,23 @@
+import torch
+from torch import nn
+
+
+class LNN(nn.Module):
+ # A fully connected network for handwritten-digit classification; the dim argument controls the hidden-layer width
+ def __init__(self, dim=32):
+ super(LNN, self).__init__()
+ self.fc1 = nn.Linear(28 * 28, dim)
+ self.fc2 = nn.Linear(dim, 10)
+
+ def forward(self, x):
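+ # flatten (N, 1, 28, 28) inputs to (N, 784) vectors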
+ x = x.view(-1, x.shape[1] * x.shape[2] * x.shape[3])
+ x = torch.relu(self.fc1(x))
+ x = self.fc2(x)
+ return x
+
+
+if __name__ == "__main__":
+ input = torch.randn(2, 1, 28, 28)
+ model = LNN()
+ output = model(input)
+ assert output.shape == (2, 10)
diff --git a/src/models/components/scnn.py b/src/models/components/scnn.py
index 086c46bc546113b14ecc99c3e29340e0fd72d5e8..171722cfa7647ea8bc0de6de1484c656878aee1f 100644
--- a/src/models/components/scnn.py
+++ b/src/models/components/scnn.py
@@ -12,7 +12,7 @@ import torch.nn as nn
import torch.nn.functional as F
-class SCNNNet(nn.Module):
+class SCNN(nn.Module):
def __init__(self, in_channels=3, num_classes=2, dropout_p=0.5):
super().__init__()
self.conv1 = nn.Conv2d(in_channels, 64, kernel_size=1)
@@ -29,7 +29,7 @@ class SCNNNet(nn.Module):
if __name__ == '__main__':
- model = SCNNNet(num_classes=7)
+ model = SCNN(num_classes=7)
fake_img = torch.randn((2, 3, 224, 224))
out = model(fake_img)
print(out.shape)
diff --git a/src/models/components/unet.py b/src/models/components/unet.py
new file mode 100644
index 0000000000000000000000000000000000000000..5e1ed8561d4d02dbac4c9938260ff7878c23bda0
--- /dev/null
+++ b/src/models/components/unet.py
@@ -0,0 +1,63 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+class UNet(nn.Module):
+ def __init__(self, in_channels, out_channels):
+ super(UNet, self).__init__()
+
+ def conv_block(in_channels, out_channels):
+ return nn.Sequential(
+ nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
+ nn.ReLU(inplace=True),
+ nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
+ nn.ReLU(inplace=True)
+ )
+
+ self.encoder1 = conv_block(in_channels, 64)
+ self.encoder2 = conv_block(64, 128)
+ self.encoder3 = conv_block(128, 256)
+ self.encoder4 = conv_block(256, 512)
+ self.bottleneck = conv_block(512, 1024)
+
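+ # each up-convolution halves the channels; the following conv_block takes the upsampled
+ # features concatenated with the matching encoder skip, hence the doubled input channels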
+ self.upconv4 = nn.ConvTranspose2d(1024, 512, kernel_size=2, stride=2)
+ self.decoder4 = conv_block(1024, 512)
+ self.upconv3 = nn.ConvTranspose2d(512, 256, kernel_size=2, stride=2)
+ self.decoder3 = conv_block(512, 256)
+ self.upconv2 = nn.ConvTranspose2d(256, 128, kernel_size=2, stride=2)
+ self.decoder2 = conv_block(256, 128)
+ self.upconv1 = nn.ConvTranspose2d(128, 64, kernel_size=2, stride=2)
+ self.decoder1 = conv_block(128, 64)
+
+ self.final = nn.Conv2d(64, out_channels, kernel_size=1)
+
+ def forward(self, x):
+ enc1 = self.encoder1(x)
+ enc2 = self.encoder2(F.max_pool2d(enc1, kernel_size=2, stride=2))
+ enc3 = self.encoder3(F.max_pool2d(enc2, kernel_size=2, stride=2))
+ enc4 = self.encoder4(F.max_pool2d(enc3, kernel_size=2, stride=2))
+ bottleneck = self.bottleneck(F.max_pool2d(enc4, kernel_size=2, stride=2))
+
+ dec4 = self.upconv4(bottleneck)
+ dec4 = torch.cat((dec4, enc4), dim=1)
+ dec4 = self.decoder4(dec4)
+ dec3 = self.upconv3(dec4)
+ dec3 = torch.cat((dec3, enc3), dim=1)
+ dec3 = self.decoder3(dec3)
+ dec2 = self.upconv2(dec3)
+ dec2 = torch.cat((dec2, enc2), dim=1)
+ dec2 = self.decoder2(dec2)
+ dec1 = self.upconv1(dec2)
+ dec1 = torch.cat((dec1, enc1), dim=1)
+ dec1 = self.decoder1(dec1)
+
+ return self.final(dec1)
+
+if __name__ == "__main__":
+ model = UNet(in_channels=3, out_channels=7)
+ fake_img = torch.rand(size=(2, 3, 224, 224))
+ print(fake_img.shape)
+ # torch.Size([2, 3, 224, 224])
+ out = model(fake_img)
+ print(out.shape)
+ # torch.Size([2, 7, 224, 224])
\ No newline at end of file
diff --git a/src/models/components/vae.py b/src/models/components/vae.py
new file mode 100644
index 0000000000000000000000000000000000000000..9bfaecea93df715d9e48a809fe38053f65361fe1
--- /dev/null
+++ b/src/models/components/vae.py
@@ -0,0 +1,144 @@
+from typing import List
+
+import matplotlib.pyplot as plt
+import torch
+import torch.nn as nn
+from src.plugin.ldm.modules.diffusionmodules.model import Encoder, Decoder
+from src.plugin.ldm.modules.distributions.distributions import DiagonalGaussianDistribution
+
+
+class AutoencoderKL(nn.Module):
+ def __init__(
+ self,
+ double_z: bool = True,
+ z_channels: int = 3,
+ resolution: int = 512,
+ in_channels: int = 3,
+ out_ch: int = 3,
+ ch: int = 128,
+ ch_mult: List = [1, 2, 4, 4],
+ num_res_blocks: int = 2,
+ attn_resolutions: List = [],
+ dropout: float = 0.0,
+ embed_dim: int = 3,
+ ckpt_path: str = None,
+ ignore_keys: List = [],
+ ):
+ super(AutoencoderKL, self).__init__()
+ ddconfig = {
+ "double_z": double_z,
+ "z_channels": z_channels,
+ "resolution": resolution,
+ "in_channels": in_channels,
+ "out_ch": out_ch,
+ "ch": ch,
+ "ch_mult": ch_mult,
+ "num_res_blocks": num_res_blocks,
+ "attn_resolutions": attn_resolutions,
+ "dropout": dropout
+ }
+ self.encoder = Encoder(**ddconfig)
+ self.decoder = Decoder(**ddconfig)
+ assert ddconfig["double_z"]
+ self.quant_conv = nn.Conv2d(
+ 2 * ddconfig["z_channels"], 2 * embed_dim, 1)
+ self.post_quant_conv = nn.Conv2d(embed_dim, ddconfig["z_channels"], 1)
+ self.embed_dim = embed_dim
+ if ckpt_path is not None:
+ self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys)
+
+ def init_from_ckpt(self, path, ignore_keys=list()):
+ sd = torch.load(path, map_location="cpu")["state_dict"]
+ keys = list(sd.keys())
+ for k in keys:
+ for ik in ignore_keys:
+ if k.startswith(ik):
+ print(f"Deleting key {k} from state_dict.")
+ del sd[k]
+ self.load_state_dict(sd, strict=False)
+ print(f"Restored from {path}")
+
+ def encode(self, x):
+ h = self.encoder(x) # B, C, h, w
+ moments = self.quant_conv(h) # B, 6, h, w
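+ # the 2 * embed_dim channels are split by DiagonalGaussianDistribution into the mean and
+ # log-variance of a diagonal Gaussian posterior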
+ posterior = DiagonalGaussianDistribution(moments)
+ return posterior  # posterior distribution
+
+ def decode(self, z):
+ z = self.post_quant_conv(z)
+ dec = self.decoder(z)
+ return dec
+
+ def forward(self, input, sample_posterior=True):
+ posterior = self.encode(input)  # diagonal Gaussian posterior
+ if sample_posterior:
+ z = posterior.sample()  # sample from the posterior
+ else:
+ z = posterior.mode()
+ dec = self.decode(z)
+ last_layer_weight = self.decoder.conv_out.weight
+ return dec, posterior, last_layer_weight
+
+
+if __name__ == '__main__':
+ # Test the input and output shapes of the model
+ model = AutoencoderKL()
+ x = torch.randn(1, 3, 512, 512)
+ dec, posterior, last_layer_weight = model(x)
+
+ assert dec.shape == (1, 3, 512, 512)
+ assert posterior.sample().shape == posterior.mode().shape == (1, 3, 64, 64)
+ assert last_layer_weight.shape == (3, 128, 3, 3)
+
+ # Plot the latent space and the reconstruction from the pretrained model
+ model = AutoencoderKL(ckpt_path="/mnt/chongqinggeminiceph1fs/geminicephfs/wx-mm-spr-xxxx/zouxuechao/Collaborative-Diffusion/outputs/512_vae/2024-06-27T06-02-04_512_vae/checkpoints/epoch=000036.ckpt")
+ model.eval()
+ image_path = "data/celeba/image/image_512_downsampled_from_hq_1024/0.jpg"
+
+ from PIL import Image
+ import numpy as np
+ from src.data.components.celeba import DalleTransformerPreprocessor
+
+ image = Image.open(image_path).convert('RGB')
+ image = np.array(image).astype(np.uint8)
+ import copy
+ original = copy.deepcopy(image)
+ transform = DalleTransformerPreprocessor(size=512, phase='test')
+ image = transform(image=image)['image']
+ image = image.astype(np.float32)/127.5 - 1.0
+ image = torch.from_numpy(image).permute(2, 0, 1).unsqueeze(0)
+
+ dec, posterior, last_layer_weight = model(image)
+
+ # original image
+ plt.subplot(1, 3, 1)
+ plt.imshow(original)
+ plt.title("Original")
+ plt.axis("off")
+
+ # sampled image from the latent space
+ plt.subplot(1, 3, 2)
+ x = model.decode(posterior.sample())
+ x = (x+1)/2
+ x = x.squeeze(0).permute(1, 2, 0).cpu()
+ x = x.detach().numpy()
+ x = x.clip(0, 1)
+ x = (x*255).astype(np.uint8)
+ plt.imshow(x)
+ plt.title("Sampled")
+ plt.axis("off")
+
+ # reconstructed image
+ plt.subplot(1, 3, 3)
+ x = dec
+ x = (x+1)/2
+ x = x.squeeze(0).permute(1, 2, 0).cpu()
+ x = x.detach().numpy()
+ x = x.clip(0, 1)
+ x = (x*255).astype(np.uint8)
+ plt.imshow(x)
+ plt.title("Reconstructed")
+ plt.axis("off")
+
+ plt.tight_layout()
+ plt.savefig("vae_reconstruction.png")
diff --git a/src/models/mnist_module.py b/src/models/mnist_module.py
new file mode 100644
index 0000000000000000000000000000000000000000..fde934427b7ef0fbf6e4e5001b27bf52a48520a4
--- /dev/null
+++ b/src/models/mnist_module.py
@@ -0,0 +1,217 @@
+from typing import Any, Dict, Tuple
+
+import torch
+from lightning import LightningModule
+from torchmetrics import MaxMetric, MeanMetric
+from torchmetrics.classification.accuracy import Accuracy
+
+
+class MNISTLitModule(LightningModule):
+ """Example of a `LightningModule` for MNIST classification.
+
+ A `LightningModule` implements 8 key methods:
+
+ ```python
+ def __init__(self):
+ # Define initialization code here.
+
+ def setup(self, stage):
+ # Things to setup before each stage, 'fit', 'validate', 'test', 'predict'.
+ # This hook is called on every process when using DDP.
+
+ def training_step(self, batch, batch_idx):
+ # The complete training step.
+
+ def validation_step(self, batch, batch_idx):
+ # The complete validation step.
+
+ def test_step(self, batch, batch_idx):
+ # The complete test step.
+
+ def predict_step(self, batch, batch_idx):
+ # The complete predict step.
+
+ def configure_optimizers(self):
+ # Define and configure optimizers and LR schedulers.
+ ```
+
+ Docs:
+ https://lightning.ai/docs/pytorch/latest/common/lightning_module.html
+ """
+
+ def __init__(
+ self,
+ net: torch.nn.Module,
+ optimizer: torch.optim.Optimizer,
+ scheduler: torch.optim.lr_scheduler,
+ compile: bool,
+ ) -> None:
+ """Initialize a `MNISTLitModule`.
+
+ :param net: The model to train.
+ :param optimizer: The optimizer to use for training.
+ :param scheduler: The learning rate scheduler to use for training.
+ """
+ super().__init__()
+
+ # this line allows to access init params with 'self.hparams' attribute
+ # also ensures init params will be stored in ckpt
+ self.save_hyperparameters(logger=False, ignore=['net'])
+
+ self.net = net
+
+ # loss function
+ self.criterion = torch.nn.CrossEntropyLoss()
+
+ # metric objects for calculating and averaging accuracy across batches
+ self.train_acc = Accuracy(task="multiclass", num_classes=10)
+ self.val_acc = Accuracy(task="multiclass", num_classes=10)
+ self.test_acc = Accuracy(task="multiclass", num_classes=10)
+
+ # for averaging loss across batches
+ self.train_loss = MeanMetric()
+ self.val_loss = MeanMetric()
+ self.test_loss = MeanMetric()
+
+ # for tracking best so far validation accuracy
+ self.val_acc_best = MaxMetric()
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ """Perform a forward pass through the model `self.net`.
+
+ :param x: A tensor of images.
+ :return: A tensor of logits.
+ """
+ return self.net(x)
+
+ def on_train_start(self) -> None:
+ """Lightning hook that is called when training begins."""
+ # by default lightning executes validation step sanity checks before training starts,
+ # so it's worth to make sure validation metrics don't store results from these checks
+ self.val_loss.reset()
+ self.val_acc.reset()
+ self.val_acc_best.reset()
+
+ def model_step(
+ self, batch: Tuple[torch.Tensor, torch.Tensor]
+ ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+ """Perform a single model step on a batch of data.
+
+ :param batch: A batch of data (a tuple) containing the input tensor of images and target labels.
+
+ :return: A tuple containing (in order):
+ - A tensor of losses.
+ - A tensor of predictions.
+ - A tensor of target labels.
+ """
+ x, y = batch
+ logits = self.forward(x)
+ loss = self.criterion(logits, y)
+ preds = torch.argmax(logits, dim=1)
+ return loss, preds, y
+
+ def training_step(
+ self, batch: Tuple[torch.Tensor, torch.Tensor], batch_idx: int
+ ) -> torch.Tensor:
+ """Perform a single training step on a batch of data from the training set.
+
+ :param batch: A batch of data (a tuple) containing the input tensor of images and target
+ labels.
+ :param batch_idx: The index of the current batch.
+ :return: A tensor of losses between model predictions and targets.
+ """
+ loss, preds, targets = self.model_step(batch)
+
+ # update and log metrics
+ self.train_loss(loss)
+ self.train_acc(preds, targets)
+ self.log("train/loss", self.train_loss, on_step=False, on_epoch=True, prog_bar=True)
+ self.log("train/acc", self.train_acc, on_step=False, on_epoch=True, prog_bar=True)
+
+ # return loss or backpropagation will fail
+ return loss
+
+ def on_train_epoch_end(self) -> None:
+ "Lightning hook that is called when a training epoch ends."
+ pass
+
+ def validation_step(self, batch: Tuple[torch.Tensor, torch.Tensor], batch_idx: int) -> None:
+ """Perform a single validation step on a batch of data from the validation set.
+
+ :param batch: A batch of data (a tuple) containing the input tensor of images and target
+ labels.
+ :param batch_idx: The index of the current batch.
+ """
+ loss, preds, targets = self.model_step(batch)
+
+ # update and log metrics
+ self.val_loss(loss)
+ self.val_acc(preds, targets)
+ self.log("val/loss", self.val_loss, on_step=False, on_epoch=True, prog_bar=True)
+ self.log("val/acc", self.val_acc, on_step=False, on_epoch=True, prog_bar=True)
+
+ def on_validation_epoch_end(self) -> None:
+ "Lightning hook that is called when a validation epoch ends."
+ acc = self.val_acc.compute() # get current val acc
+ self.val_acc_best(acc) # update best so far val acc
+ # log `val_acc_best` as a value through `.compute()` method, instead of as a metric object
+ # otherwise metric would be reset by lightning after each epoch
+ self.log("val/acc_best", self.val_acc_best.compute(), sync_dist=True, prog_bar=True)
+
+ def test_step(self, batch: Tuple[torch.Tensor, torch.Tensor], batch_idx: int) -> None:
+ """Perform a single test step on a batch of data from the test set.
+
+ :param batch: A batch of data (a tuple) containing the input tensor of images and target
+ labels.
+ :param batch_idx: The index of the current batch.
+ """
+ loss, preds, targets = self.model_step(batch)
+
+ # update and log metrics
+ self.test_loss(loss)
+ self.test_acc(preds, targets)
+ self.log("test/loss", self.test_loss, on_step=False, on_epoch=True, prog_bar=True)
+ self.log("test/acc", self.test_acc, on_step=False, on_epoch=True, prog_bar=True)
+
+ def on_test_epoch_end(self) -> None:
+ """Lightning hook that is called when a test epoch ends."""
+ pass
+
+ def setup(self, stage: str) -> None:
+ """Lightning hook that is called at the beginning of fit (train + validate), validate,
+ test, or predict.
+
+ This is a good hook when you need to build models dynamically or adjust something about
+ them. This hook is called on every process when using DDP.
+
+ :param stage: Either `"fit"`, `"validate"`, `"test"`, or `"predict"`.
+ """
+ if self.hparams.compile and stage == "fit":
+ self.net = torch.compile(self.net)
+
+ def configure_optimizers(self) -> Dict[str, Any]:
+ """Choose what optimizers and learning-rate schedulers to use in your optimization.
+ Normally you'd need one. But in the case of GANs or similar you might have multiple.
+
+ Examples:
+ https://lightning.ai/docs/pytorch/latest/common/lightning_module.html#configure-optimizers
+
+ :return: A dict containing the configured optimizers and learning-rate schedulers to be used for training.
+ """
+ optimizer = self.hparams.optimizer(params=self.trainer.model.parameters())
+ if self.hparams.scheduler is not None:
+ scheduler = self.hparams.scheduler(optimizer=optimizer)
+ return {
+ "optimizer": optimizer,
+ "lr_scheduler": {
+ "scheduler": scheduler,
+ "monitor": "val/loss",
+ "interval": "epoch",
+ "frequency": 1,
+ },
+ }
+ return {"optimizer": optimizer}
+
+
+if __name__ == "__main__":
+ _ = MNISTLitModule(None, None, None, None)
diff --git a/src/train.py b/src/train.py
new file mode 100644
index 0000000000000000000000000000000000000000..c272227dbc704d1a9eeab67e7e5dcecb879e87ea
--- /dev/null
+++ b/src/train.py
@@ -0,0 +1,133 @@
+from typing import Any, Dict, List, Optional, Tuple
+
+import hydra
+import lightning as L
+import rootutils
+import torch
+from lightning import Callback, LightningDataModule, LightningModule, Trainer
+from lightning.pytorch.loggers import Logger
+from omegaconf import DictConfig
+
+rootutils.setup_root(__file__, indicator=".project-root", pythonpath=True)
+# ------------------------------------------------------------------------------------ #
+# the setup_root above is equivalent to:
+# - adding project root dir to PYTHONPATH
+# (so you don't need to force user to install project as a package)
+# (necessary before importing any local modules e.g. `from src import utils`)
+# - setting up PROJECT_ROOT environment variable
+# (which is used as a base for paths in "configs/paths/default.yaml")
+# (this way all filepaths are the same no matter where you run the code)
+# - loading environment variables from ".env" in root dir
+#
+# you can remove it if you:
+# 1. either install project as a package or move entry files to project root dir
+# 2. set `root_dir` to "." in "configs/paths/default.yaml"
+#
+# more info: https://github.com/ashleve/rootutils
+# ------------------------------------------------------------------------------------ #
+
+from src.utils import (
+ RankedLogger,
+ extras,
+ get_metric_value,
+ instantiate_callbacks,
+ instantiate_loggers,
+ log_hyperparameters,
+ task_wrapper,
+)
+
+log = RankedLogger(__name__, rank_zero_only=True)
+
+
+@task_wrapper
+def train(cfg: DictConfig) -> Tuple[Dict[str, Any], Dict[str, Any]]:
+ """Trains the model. Can additionally evaluate on a testset, using best weights obtained during
+ training.
+
+ This method is wrapped in an optional @task_wrapper decorator, which controls the behavior
+ during failure. Useful for multiruns, saving info about the crash, etc.
+
+ :param cfg: A DictConfig configuration composed by Hydra.
+ :return: A tuple with metrics and dict with all instantiated objects.
+ """
+ # set seed for random number generators in pytorch, numpy and python.random
+ if cfg.get("seed"):
+ L.seed_everything(cfg.seed, workers=True)
+
+ log.info(f"Instantiating datamodule <{cfg.data._target_}>")
+ datamodule: LightningDataModule = hydra.utils.instantiate(cfg.data)
+
+ log.info(f"Instantiating model <{cfg.model._target_}>")
+ model: LightningModule = hydra.utils.instantiate(cfg.model)
+
+ log.info("Instantiating callbacks...")
+ callbacks: List[Callback] = instantiate_callbacks(cfg.get("callbacks"))
+
+ log.info("Instantiating loggers...")
+ logger: List[Logger] = instantiate_loggers(cfg.get("logger"))
+
+ log.info(f"Instantiating trainer <{cfg.trainer._target_}>")
+ trainer: Trainer = hydra.utils.instantiate(cfg.trainer, callbacks=callbacks, logger=logger)
+
+ object_dict = {
+ "cfg": cfg,
+ "datamodule": datamodule,
+ "model": model,
+ "callbacks": callbacks,
+ "logger": logger,
+ "trainer": trainer,
+ }
+
+ if logger:
+ log.info("Logging hyperparameters!")
+ log_hyperparameters(object_dict)
+
+ if cfg.get("train"):
+ log.info("Starting training!")
+ trainer.fit(model=model, datamodule=datamodule, ckpt_path=cfg.get("ckpt_path"))
+
+ train_metrics = trainer.callback_metrics
+
+ if cfg.get("test"):
+ log.info("Starting testing!")
+ ckpt_path = trainer.checkpoint_callback.best_model_path
+ if ckpt_path == "":
+ log.warning("Best ckpt not found! Using current weights for testing...")
+ ckpt_path = None
+ trainer.test(model=model, datamodule=datamodule, ckpt_path=ckpt_path)
+ log.info(f"Best ckpt path: {ckpt_path}")
+
+ test_metrics = trainer.callback_metrics
+
+ # merge train and test metrics
+ metric_dict = {**train_metrics, **test_metrics}
+
+ return metric_dict, object_dict
+
+
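+# usage sketch (overrides are illustrative; adjust to the configs shipped with this repo):
+#   python src/train.py trainer.max_epochs=5
+#   python src/train.py -m hparams_search=mnist_optuna  # hydra multirun for hyperparameter search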
+@hydra.main(version_base="1.3", config_path="../configs", config_name="train.yaml")
+def main(cfg: DictConfig) -> Optional[float]:
+ """Main entry point for training.
+
+ :param cfg: DictConfig configuration composed by Hydra.
+ :return: Optional[float] with optimized metric value.
+ """
+ # apply extra utilities
+ # (e.g. ask for tags if none are provided in cfg, print cfg tree, etc.)
+ extras(cfg)
+
+ # train the model
+ metric_dict, _ = train(cfg)
+
+ # safely retrieve metric value for hydra-based hyperparameter optimization
+ metric_value = get_metric_value(
+ metric_dict=metric_dict, metric_name=cfg.get("optimized_metric")
+ )
+
+ # return optimized metric
+ return metric_value
+
+
+if __name__ == "__main__":
+ torch.set_float32_matmul_precision("high")
+ main()
diff --git a/src/utils/__init__.py b/src/utils/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..5b0707ca57ec89fc5f5cb1a023135eeb756a8e1e
--- /dev/null
+++ b/src/utils/__init__.py
@@ -0,0 +1,5 @@
+from src.utils.instantiators import instantiate_callbacks, instantiate_loggers
+from src.utils.logging_utils import log_hyperparameters
+from src.utils.pylogger import RankedLogger
+from src.utils.rich_utils import enforce_tags, print_config_tree
+from src.utils.utils import extras, get_metric_value, task_wrapper
diff --git a/src/utils/instantiators.py b/src/utils/instantiators.py
new file mode 100644
index 0000000000000000000000000000000000000000..82b9278a465d39565942f862442ebe79549825d7
--- /dev/null
+++ b/src/utils/instantiators.py
@@ -0,0 +1,56 @@
+from typing import List
+
+import hydra
+from lightning import Callback
+from lightning.pytorch.loggers import Logger
+from omegaconf import DictConfig
+
+from src.utils import pylogger
+
+log = pylogger.RankedLogger(__name__, rank_zero_only=True)
+
+
+def instantiate_callbacks(callbacks_cfg: DictConfig) -> List[Callback]:
+ """Instantiates callbacks from config.
+
+ :param callbacks_cfg: A DictConfig object containing callback configurations.
+ :return: A list of instantiated callbacks.
+ """
+ callbacks: List[Callback] = []
+
+ if not callbacks_cfg:
+ log.warning("No callback configs found! Skipping..")
+ return callbacks
+
+ if not isinstance(callbacks_cfg, DictConfig):
+ raise TypeError("Callbacks config must be a DictConfig!")
+
+ for _, cb_conf in callbacks_cfg.items():
+ if isinstance(cb_conf, DictConfig) and "_target_" in cb_conf:
+ log.info(f"Instantiating callback <{cb_conf._target_}>")
+ callbacks.append(hydra.utils.instantiate(cb_conf))
+
+ return callbacks
+
+
+def instantiate_loggers(logger_cfg: DictConfig) -> List[Logger]:
+ """Instantiates loggers from config.
+
+ :param logger_cfg: A DictConfig object containing logger configurations.
+ :return: A list of instantiated loggers.
+ """
+ logger: List[Logger] = []
+
+ if not logger_cfg:
+ log.warning("No logger configs found! Skipping...")
+ return logger
+
+ if not isinstance(logger_cfg, DictConfig):
+ raise TypeError("Logger config must be a DictConfig!")
+
+ for _, lg_conf in logger_cfg.items():
+ if isinstance(lg_conf, DictConfig) and "_target_" in lg_conf:
+ log.info(f"Instantiating logger <{lg_conf._target_}>")
+ logger.append(hydra.utils.instantiate(lg_conf))
+
+ return logger
diff --git a/src/utils/logging_utils.py b/src/utils/logging_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..360abcdceec82e551995f756ce6ec3b2d06ae641
--- /dev/null
+++ b/src/utils/logging_utils.py
@@ -0,0 +1,57 @@
+from typing import Any, Dict
+
+from lightning_utilities.core.rank_zero import rank_zero_only
+from omegaconf import OmegaConf
+
+from src.utils import pylogger
+
+log = pylogger.RankedLogger(__name__, rank_zero_only=True)
+
+
+@rank_zero_only
+def log_hyperparameters(object_dict: Dict[str, Any]) -> None:
+ """Controls which config parts are saved by Lightning loggers.
+
+ Additionally saves:
+ - Number of model parameters
+
+ :param object_dict: A dictionary containing the following objects:
+ - `"cfg"`: A DictConfig object containing the main config.
+ - `"model"`: The Lightning model.
+ - `"trainer"`: The Lightning trainer.
+ """
+ hparams = {}
+
+ cfg = OmegaConf.to_container(object_dict["cfg"])
+ model = object_dict["model"]
+ trainer = object_dict["trainer"]
+
+ if not trainer.logger:
+ log.warning("Logger not found! Skipping hyperparameter logging...")
+ return
+
+ hparams["model"] = cfg["model"]
+
+ # save number of model parameters
+ hparams["model/params/total"] = sum(p.numel() for p in model.parameters())
+ hparams["model/params/trainable"] = sum(
+ p.numel() for p in model.parameters() if p.requires_grad
+ )
+ hparams["model/params/non_trainable"] = sum(
+ p.numel() for p in model.parameters() if not p.requires_grad
+ )
+
+ hparams["data"] = cfg["data"]
+ hparams["trainer"] = cfg["trainer"]
+
+ hparams["callbacks"] = cfg.get("callbacks")
+ hparams["extras"] = cfg.get("extras")
+
+ hparams["task_name"] = cfg.get("task_name")
+ hparams["tags"] = cfg.get("tags")
+ hparams["ckpt_path"] = cfg.get("ckpt_path")
+ hparams["seed"] = cfg.get("seed")
+
+ # send hparams to all loggers
+ for logger in trainer.loggers:
+ logger.log_hyperparams(hparams)
diff --git a/src/utils/make_h5.py b/src/utils/make_h5.py
new file mode 100644
index 0000000000000000000000000000000000000000..492c16d5f254e423615c8594b51fb8032dd65f12
--- /dev/null
+++ b/src/utils/make_h5.py
@@ -0,0 +1,37 @@
+import os
+
+import h5py
+import numpy as np
+from PIL import Image
+from tqdm import tqdm
+
+
+class MNISTH5Creator:
+ def __init__(self, data_dir, h5_file):
+ self.data_dir = data_dir
+ self.h5_file = h5_file
+
+ def create_h5_file(self):
+ """创建HDF5文件,包含从0到9的子目录中的所有图像数据。"""
+ with h5py.File(self.h5_file, 'w') as h5f:
+ for i in range(10):
+ images = self.load_images_for_digit(i)
+ h5f.create_dataset(name=str(i), data=images)
+ print("HDF5文件已创建.")
+
+ def load_images_for_digit(self, digit):
+ """为给定的数字加载所有图像,并将它们转换为numpy数组。"""
+ digit_folder = os.path.join(self.data_dir, str(digit))
+ images = []
+ for img_name in tqdm(os.listdir(digit_folder), desc=f"Loading images for digit {digit}"):
+ img_path = os.path.join(digit_folder, img_name)
+ img = Image.open(img_path).convert('L')
+ img_array = np.array(img)
+ images.append(img_array)
+ return images
+
+if __name__ == "__main__":
+ data_dir = 'data/mnist'
+ h5_file = 'data/mnist.h5'
+ mnist_h5_creator = MNISTH5Creator(data_dir, h5_file)
+ mnist_h5_creator.create_h5_file()
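+
+# reading the file back (sketch, assuming the per-digit datasets "0".."9" created above):
+# with h5py.File("data/mnist.h5", "r") as h5f:
+#     digit_zero = np.array(h5f["0"])  # all images for digit 0, shape (num_images, H, W)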
diff --git a/src/utils/pylogger.py b/src/utils/pylogger.py
new file mode 100644
index 0000000000000000000000000000000000000000..c4ee8675ebde11b2a43b0679a03cd88d9268bc71
--- /dev/null
+++ b/src/utils/pylogger.py
@@ -0,0 +1,51 @@
+import logging
+from typing import Mapping, Optional
+
+from lightning_utilities.core.rank_zero import rank_prefixed_message, rank_zero_only
+
+
+class RankedLogger(logging.LoggerAdapter):
+ """A multi-GPU-friendly python command line logger."""
+
+ def __init__(
+ self,
+ name: str = __name__,
+ rank_zero_only: bool = False,
+ extra: Optional[Mapping[str, object]] = None,
+ ) -> None:
+ """Initializes a multi-GPU-friendly python command line logger that logs on all processes
+ with their rank prefixed in the log message.
+
+ :param name: The name of the logger. Default is ``__name__``.
+ :param rank_zero_only: Whether to force all logs to only occur on the rank zero process. Default is `False`.
+ :param extra: (Optional) A dict-like object which provides contextual information. See `logging.LoggerAdapter`.
+ """
+ logger = logging.getLogger(name)
+ super().__init__(logger=logger, extra=extra)
+ self.rank_zero_only = rank_zero_only
+
+ def log(self, level: int, msg: str, rank: Optional[int] = None, *args, **kwargs) -> None:
+ """Delegate a log call to the underlying logger, after prefixing its message with the rank
+ of the process it's being logged from. If `'rank'` is provided, then the log will only
+ occur on that rank/process.
+
+ :param level: The level to log at. Look at `logging.__init__.py` for more information.
+ :param msg: The message to log.
+ :param rank: The rank to log at.
+ :param args: Additional args to pass to the underlying logging function.
+ :param kwargs: Any additional keyword args to pass to the underlying logging function.
+ """
+ if self.isEnabledFor(level):
+ msg, kwargs = self.process(msg, kwargs)
+ current_rank = getattr(rank_zero_only, "rank", None)
+ if current_rank is None:
+ raise RuntimeError("The `rank_zero_only.rank` needs to be set before use")
+ msg = rank_prefixed_message(msg, current_rank)
+ if self.rank_zero_only:
+ if current_rank == 0:
+ self.logger.log(level, msg, *args, **kwargs)
+ else:
+ if rank is None:
+ self.logger.log(level, msg, *args, **kwargs)
+ elif current_rank == rank:
+ self.logger.log(level, msg, *args, **kwargs)
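+
+
+# usage sketch (names from this module):
+# log = RankedLogger(__name__, rank_zero_only=True)
+# log.info("logged only on the rank-zero process")
+# log_all = RankedLogger(__name__, rank_zero_only=False)
+# log_all.info("logged on every process, prefixed with its rank")
+# log_all.info("logged only on rank 1", rank=1)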
diff --git a/src/utils/rich_utils.py b/src/utils/rich_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..aeec6806bb1e4a15a04b91b710a546231590ab14
--- /dev/null
+++ b/src/utils/rich_utils.py
@@ -0,0 +1,99 @@
+from pathlib import Path
+from typing import Sequence
+
+import rich
+import rich.syntax
+import rich.tree
+from hydra.core.hydra_config import HydraConfig
+from lightning_utilities.core.rank_zero import rank_zero_only
+from omegaconf import DictConfig, OmegaConf, open_dict
+from rich.prompt import Prompt
+
+from src.utils import pylogger
+
+log = pylogger.RankedLogger(__name__, rank_zero_only=True)
+
+
+@rank_zero_only
+def print_config_tree(
+ cfg: DictConfig,
+ print_order: Sequence[str] = (
+ "data",
+ "model",
+ "callbacks",
+ "logger",
+ "trainer",
+ "paths",
+ "extras",
+ ),
+ resolve: bool = False,
+ save_to_file: bool = False,
+) -> None:
+ """Prints the contents of a DictConfig as a tree structure using the Rich library.
+
+ :param cfg: A DictConfig composed by Hydra.
+ :param print_order: Determines in what order config components are printed. Default is ``("data", "model",
+ "callbacks", "logger", "trainer", "paths", "extras")``.
+ :param resolve: Whether to resolve reference fields of DictConfig. Default is ``False``.
+ :param save_to_file: Whether to export config to the hydra output folder. Default is ``False``.
+ """
+ style = "dim"
+ tree = rich.tree.Tree("CONFIG", style=style, guide_style=style)
+
+ queue = []
+
+ # add fields from `print_order` to queue
+ for field in print_order:
+ if field in cfg:
+ queue.append(field)
+ else:
+ log.warning(f"Field '{field}' not found in config. Skipping '{field}' config printing...")
+
+ # add all the other fields to queue (not specified in `print_order`)
+ for field in cfg:
+ if field not in queue:
+ queue.append(field)
+
+ # generate config tree from queue
+ for field in queue:
+ branch = tree.add(field, style=style, guide_style=style)
+
+ config_group = cfg[field]
+ if isinstance(config_group, DictConfig):
+ branch_content = OmegaConf.to_yaml(config_group, resolve=resolve)
+ else:
+ branch_content = str(config_group)
+
+ branch.add(rich.syntax.Syntax(branch_content, "yaml"))
+
+ # print config tree
+ rich.print(tree)
+
+ # save config tree to file
+ if save_to_file:
+ with open(Path(cfg.paths.output_dir, "config_tree.log"), "w") as file:
+ rich.print(tree, file=file)
+
+
+@rank_zero_only
+def enforce_tags(cfg: DictConfig, save_to_file: bool = False) -> None:
+ """Prompts user to input tags from command line if no tags are provided in config.
+
+ :param cfg: A DictConfig composed by Hydra.
+ :param save_to_file: Whether to export tags to the hydra output folder. Default is ``False``.
+ """
+ if not cfg.get("tags"):
+ if "id" in HydraConfig().cfg.hydra.job:
+ raise ValueError("Specify tags before launching a multirun!")
+
+ log.warning("No tags provided in config. Prompting user to input tags...")
+ tags = Prompt.ask("Enter a list of comma separated tags", default="dev")
+ tags = [t.strip() for t in tags.split(",") if t != ""]
+
+ with open_dict(cfg):
+ cfg.tags = tags
+
+ log.info(f"Tags: {cfg.tags}")
+
+ if save_to_file:
+ with open(Path(cfg.paths.output_dir, "tags.log"), "w") as file:
+ rich.print(cfg.tags, file=file)
diff --git a/src/utils/utils.py b/src/utils/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..02b55765ad3de9441bed931a577ebbb3b669fda4
--- /dev/null
+++ b/src/utils/utils.py
@@ -0,0 +1,119 @@
+import warnings
+from importlib.util import find_spec
+from typing import Any, Callable, Dict, Optional, Tuple
+
+from omegaconf import DictConfig
+
+from src.utils import pylogger, rich_utils
+
+log = pylogger.RankedLogger(__name__, rank_zero_only=True)
+
+
+def extras(cfg: DictConfig) -> None:
+ """Applies optional utilities before the task is started.
+
+ Utilities:
+ - Ignoring python warnings
+ - Setting tags from command line
+ - Rich config printing
+
+ :param cfg: A DictConfig object containing the config tree.
+ """
+ # return if no `extras` config
+ if not cfg.get("extras"):
+ log.warning("Extras config not found! ")
+ return
+
+ # disable python warnings
+ if cfg.extras.get("ignore_warnings"):
+ log.info("Disabling python warnings! ")
+ warnings.filterwarnings("ignore")
+
+ # prompt user to input tags from command line if none are provided in the config
+ if cfg.extras.get("enforce_tags"):
+ log.info("Enforcing tags! ")
+ rich_utils.enforce_tags(cfg, save_to_file=True)
+
+ # pretty print config tree using Rich library
+ if cfg.extras.get("print_config"):
+ log.info("Printing config tree with Rich! ")
+ rich_utils.print_config_tree(cfg, resolve=True, save_to_file=True)
+
+
+def task_wrapper(task_func: Callable) -> Callable:
+ """Optional decorator that controls the failure behavior when executing the task function.
+
+ This wrapper can be used to:
+ - make sure loggers are closed even if the task function raises an exception (prevents multirun failure)
+ - save the exception to a `.log` file
+ - mark the run as failed with a dedicated file in the `logs/` folder (so we can find and rerun it later)
+ - etc. (adjust depending on your needs)
+
+ Example:
+ ```
+ @utils.task_wrapper
+ def train(cfg: DictConfig) -> Tuple[Dict[str, Any], Dict[str, Any]]:
+ ...
+ return metric_dict, object_dict
+ ```
+
+ :param task_func: The task function to be wrapped.
+
+ :return: The wrapped task function.
+ """
+
+ def wrap(cfg: DictConfig) -> Tuple[Dict[str, Any], Dict[str, Any]]:
+ # execute the task
+ try:
+ metric_dict, object_dict = task_func(cfg=cfg)
+
+ # things to do if exception occurs
+ except Exception as ex:
+ # save exception to `.log` file
+ log.exception("")
+
+ # some hyperparameter combinations might be invalid or cause out-of-memory errors
+ # so when using hparam search plugins like Optuna, you might want to disable
+ # raising the below exception to avoid multirun failure
+ raise ex
+
+ # things to always do after either success or exception
+ finally:
+ # display output dir path in terminal
+ log.info(f"Output dir: {cfg.paths.output_dir}")
+
+ # always close wandb run (even if exception occurs so multirun won't fail)
+ if find_spec("wandb"): # check if wandb is installed
+ import wandb
+
+ if wandb.run:
+ log.info("Closing wandb!")
+ wandb.finish()
+
+ return metric_dict, object_dict
+
+ return wrap
+
+
+def get_metric_value(metric_dict: Dict[str, Any], metric_name: Optional[str]) -> Optional[float]:
+ """Safely retrieves value of the metric logged in LightningModule.
+
+ :param metric_dict: A dict containing metric values.
+ :param metric_name: If provided, the name of the metric to retrieve.
+ :return: If a metric name was provided, the value of the metric.
+ """
+ if not metric_name:
+ log.info("Metric name is None! Skipping metric value retrieval...")
+ return None
+
+ if metric_name not in metric_dict:
+ raise Exception(
+ f"Metric value not found! \n"
+ "Make sure metric name logged in LightningModule is correct!\n"
+ "Make sure `optimized_metric` name in `hparams_search` config is correct!"
+ )
+
+ metric_value = metric_dict[metric_name].item()
+ log.info(f"Retrieved metric value! <{metric_name}={metric_value}>")
+
+ return metric_value
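+
+
+# usage sketch: retrieve the metric that a hydra sweeper optimizes, e.g.
+# value = get_metric_value(metric_dict=trainer.callback_metrics, metric_name="val/acc_best")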
diff --git a/tests/__init__.py b/tests/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/tests/conftest.py b/tests/conftest.py
new file mode 100644
index 0000000000000000000000000000000000000000..b5dea333ca4818bbb1a4fdd5e6e9a70a7ebad1b4
--- /dev/null
+++ b/tests/conftest.py
@@ -0,0 +1,107 @@
+"""This file prepares config fixtures for other tests."""
+
+from pathlib import Path
+
+import pytest
+import rootutils
+from hydra import compose, initialize
+from hydra.core.global_hydra import GlobalHydra
+from omegaconf import DictConfig, open_dict
+
+
+@pytest.fixture(scope="package")
+def cfg_train_global() -> DictConfig:
+ """A pytest fixture for setting up a default Hydra DictConfig for training.
+
+ :return: A DictConfig object containing a default Hydra configuration for training.
+ """
+ with initialize(version_base="1.3", config_path="../configs"):
+ cfg = compose(config_name="train.yaml", return_hydra_config=True, overrides=[])
+
+ # set defaults for all tests
+ with open_dict(cfg):
+ cfg.paths.root_dir = str(rootutils.find_root(indicator=".project-root"))
+ cfg.trainer.max_epochs = 1
+ cfg.trainer.limit_train_batches = 0.01
+ cfg.trainer.limit_val_batches = 0.1
+ cfg.trainer.limit_test_batches = 0.1
+ cfg.trainer.accelerator = "cpu"
+ cfg.trainer.devices = 1
+ cfg.data.num_workers = 0
+ cfg.data.pin_memory = False
+ cfg.extras.print_config = False
+ cfg.extras.enforce_tags = False
+ cfg.logger = None
+
+ return cfg
+
+
+@pytest.fixture(scope="package")
+def cfg_eval_global() -> DictConfig:
+ """A pytest fixture for setting up a default Hydra DictConfig for evaluation.
+
+ :return: A DictConfig containing a default Hydra configuration for evaluation.
+ """
+ with initialize(version_base="1.3", config_path="../configs"):
+ cfg = compose(config_name="eval.yaml", return_hydra_config=True, overrides=["ckpt_path=."])
+
+ # set defaults for all tests
+ with open_dict(cfg):
+ cfg.paths.root_dir = str(rootutils.find_root(indicator=".project-root"))
+ cfg.trainer.max_epochs = 1
+ cfg.trainer.limit_test_batches = 0.1
+ cfg.trainer.accelerator = "cpu"
+ cfg.trainer.devices = 1
+ cfg.data.num_workers = 0
+ cfg.data.pin_memory = False
+ cfg.extras.print_config = False
+ cfg.extras.enforce_tags = False
+ cfg.logger = None
+
+ return cfg
+
+
+@pytest.fixture(scope="function")
+def cfg_train(cfg_train_global: DictConfig, tmp_path: Path) -> DictConfig:
+ """A pytest fixture built on top of the `cfg_train_global()` fixture, which accepts a temporary
+ logging path `tmp_path` for generating a temporary logging path.
+
+ This is called by each test which uses the `cfg_train` arg. Each test generates its own temporary logging path.
+
+ :param cfg_train_global: The input DictConfig object to be modified.
+ :param tmp_path: The temporary logging path.
+
+ :return: A DictConfig with updated output and log directories corresponding to `tmp_path`.
+ """
+ cfg = cfg_train_global.copy()
+
+ with open_dict(cfg):
+ cfg.paths.output_dir = str(tmp_path)
+ cfg.paths.log_dir = str(tmp_path)
+
+ yield cfg
+
+ GlobalHydra.instance().clear()
+
+
+@pytest.fixture(scope="function")
+def cfg_eval(cfg_eval_global: DictConfig, tmp_path: Path) -> DictConfig:
+ """A pytest fixture built on top of the `cfg_eval_global()` fixture, which accepts a temporary
+ logging path `tmp_path` for generating a temporary logging path.
+
+ This is called by each test which uses the `cfg_eval` arg. Each test generates its own temporary logging path.
+
+ :param cfg_eval_global: The input DictConfig object to be modified.
+ :param tmp_path: The temporary logging path.
+
+ :return: A DictConfig with updated output and log directories corresponding to `tmp_path`.
+ """
+ cfg = cfg_eval_global.copy()
+
+ with open_dict(cfg):
+ cfg.paths.output_dir = str(tmp_path)
+ cfg.paths.log_dir = str(tmp_path)
+
+ yield cfg
+
+ GlobalHydra.instance().clear()
diff --git a/tests/helpers/__init__.py b/tests/helpers/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/tests/helpers/package_available.py b/tests/helpers/package_available.py
new file mode 100644
index 0000000000000000000000000000000000000000..0afdba8dc1efd49f9d8c1a47ede62b7e206b99f3
--- /dev/null
+++ b/tests/helpers/package_available.py
@@ -0,0 +1,32 @@
+import platform
+
+import pkg_resources
+from lightning.fabric.accelerators import TPUAccelerator
+
+
+def _package_available(package_name: str) -> bool:
+ """Check if a package is available in your environment.
+
+ :param package_name: The name of the package to be checked.
+
+ :return: `True` if the package is available. `False` otherwise.
+ """
+ try:
+ return pkg_resources.require(package_name) is not None
+ except pkg_resources.DistributionNotFound:
+ return False
+
+
+_TPU_AVAILABLE = TPUAccelerator.is_available()
+
+_IS_WINDOWS = platform.system() == "Windows"
+
+_SH_AVAILABLE = not _IS_WINDOWS and _package_available("sh")
+
+_DEEPSPEED_AVAILABLE = not _IS_WINDOWS and _package_available("deepspeed")
+_FAIRSCALE_AVAILABLE = not _IS_WINDOWS and _package_available("fairscale")
+
+_WANDB_AVAILABLE = _package_available("wandb")
+_NEPTUNE_AVAILABLE = _package_available("neptune")
+_COMET_AVAILABLE = _package_available("comet_ml")
+_MLFLOW_AVAILABLE = _package_available("mlflow")
diff --git a/tests/helpers/run_if.py b/tests/helpers/run_if.py
new file mode 100644
index 0000000000000000000000000000000000000000..9703af425129d0225d0aeed20dedc3ed35bc7548
--- /dev/null
+++ b/tests/helpers/run_if.py
@@ -0,0 +1,142 @@
+"""Adapted from:
+
+https://github.com/PyTorchLightning/pytorch-lightning/blob/master/tests/helpers/runif.py
+"""
+
+import sys
+from typing import Any, Dict, Optional
+
+import pytest
+import torch
+from packaging.version import Version
+from pkg_resources import get_distribution
+from pytest import MarkDecorator
+
+from tests.helpers.package_available import (
+ _COMET_AVAILABLE,
+ _DEEPSPEED_AVAILABLE,
+ _FAIRSCALE_AVAILABLE,
+ _IS_WINDOWS,
+ _MLFLOW_AVAILABLE,
+ _NEPTUNE_AVAILABLE,
+ _SH_AVAILABLE,
+ _TPU_AVAILABLE,
+ _WANDB_AVAILABLE,
+)
+
+
+class RunIf:
+ """RunIf wrapper for conditional skipping of tests.
+
+ Fully compatible with `@pytest.mark`.
+
+ Example:
+
+ ```python
+ @RunIf(min_torch="1.8")
+ @pytest.mark.parametrize("arg1", [1.0, 2.0])
+ def test_wrapper(arg1):
+ assert arg1 > 0
+ ```
+ """
+
+ def __new__(
+ cls,
+ min_gpus: int = 0,
+ min_torch: Optional[str] = None,
+ max_torch: Optional[str] = None,
+ min_python: Optional[str] = None,
+ skip_windows: bool = False,
+ sh: bool = False,
+ tpu: bool = False,
+ fairscale: bool = False,
+ deepspeed: bool = False,
+ wandb: bool = False,
+ neptune: bool = False,
+ comet: bool = False,
+ mlflow: bool = False,
+ **kwargs: Dict[Any, Any],
+ ) -> MarkDecorator:
+ """Creates a new `@RunIf` `MarkDecorator` decorator.
+
+ :param min_gpus: Min number of GPUs required to run test.
+ :param min_torch: Minimum pytorch version to run test.
+ :param max_torch: Maximum pytorch version to run test.
+ :param min_python: Minimum python version required to run test.
+ :param skip_windows: Skip test for Windows platform.
+ :param tpu: If TPU is available.
+ :param sh: If `sh` module is required to run the test.
+ :param fairscale: If `fairscale` module is required to run the test.
+ :param deepspeed: If `deepspeed` module is required to run the test.
+ :param wandb: If `wandb` module is required to run the test.
+ :param neptune: If `neptune` module is required to run the test.
+ :param comet: If `comet` module is required to run the test.
+ :param mlflow: If `mlflow` module is required to run the test.
+ :param kwargs: Native `pytest.mark.skipif` keyword arguments.
+ """
+ conditions = []
+ reasons = []
+
+ if min_gpus:
+ conditions.append(torch.cuda.device_count() < min_gpus)
+ reasons.append(f"GPUs>={min_gpus}")
+
+ if min_torch:
+ torch_version = get_distribution("torch").version
+ conditions.append(Version(torch_version) < Version(min_torch))
+ reasons.append(f"torch>={min_torch}")
+
+ if max_torch:
+ torch_version = get_distribution("torch").version
+ conditions.append(Version(torch_version) >= Version(max_torch))
+ reasons.append(f"torch<{max_torch}")
+
+ if min_python:
+ py_version = (
+ f"{sys.version_info.major}.{sys.version_info.minor}.{sys.version_info.micro}"
+ )
+ conditions.append(Version(py_version) < Version(min_python))
+ reasons.append(f"python>={min_python}")
+
+ if skip_windows:
+ conditions.append(_IS_WINDOWS)
+ reasons.append("does not run on Windows")
+
+ if tpu:
+ conditions.append(not _TPU_AVAILABLE)
+ reasons.append("TPU")
+
+ if sh:
+ conditions.append(not _SH_AVAILABLE)
+ reasons.append("sh")
+
+ if fairscale:
+ conditions.append(not _FAIRSCALE_AVAILABLE)
+ reasons.append("fairscale")
+
+ if deepspeed:
+ conditions.append(not _DEEPSPEED_AVAILABLE)
+ reasons.append("deepspeed")
+
+ if wandb:
+ conditions.append(not _WANDB_AVAILABLE)
+ reasons.append("wandb")
+
+ if neptune:
+ conditions.append(not _NEPTUNE_AVAILABLE)
+ reasons.append("neptune")
+
+ if comet:
+ conditions.append(not _COMET_AVAILABLE)
+ reasons.append("comet")
+
+ if mlflow:
+ conditions.append(not _MLFLOW_AVAILABLE)
+ reasons.append("mlflow")
+
+ reasons = [rs for cond, rs in zip(conditions, reasons) if cond]
+ return pytest.mark.skipif(
+ condition=any(conditions),
+ reason=f"Requires: [{' + '.join(reasons)}]",
+ **kwargs,
+ )
diff --git a/tests/helpers/run_sh_command.py b/tests/helpers/run_sh_command.py
new file mode 100644
index 0000000000000000000000000000000000000000..fdd2ed633f1185dd7936924616be6a6359a7bca7
--- /dev/null
+++ b/tests/helpers/run_sh_command.py
@@ -0,0 +1,22 @@
+from typing import List
+
+import pytest
+
+from tests.helpers.package_available import _SH_AVAILABLE
+
+if _SH_AVAILABLE:
+ import sh
+
+
+def run_sh_command(command: List[str]) -> None:
+ """Default method for executing shell commands with `pytest` and `sh` package.
+
+ :param command: A list of shell commands as strings.
+ """
+ msg = None
+ try:
+ sh.python(command)
+ except sh.ErrorReturnCode as e:
+ msg = e.stderr.decode()
+ if msg:
+ pytest.fail(msg=msg)
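+
+
+# usage sketch (overrides are illustrative):
+# run_sh_command(["src/train.py", "++trainer.fast_dev_run=true", "logger=[]"])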
diff --git a/tests/test_configs.py b/tests/test_configs.py
new file mode 100644
index 0000000000000000000000000000000000000000..d7041dc78cc207489255d8618c4a2e75ba74464d
--- /dev/null
+++ b/tests/test_configs.py
@@ -0,0 +1,37 @@
+import hydra
+from hydra.core.hydra_config import HydraConfig
+from omegaconf import DictConfig
+
+
+def test_train_config(cfg_train: DictConfig) -> None:
+ """Tests the training configuration provided by the `cfg_train` pytest fixture.
+
+ :param cfg_train: A DictConfig containing a valid training configuration.
+ """
+ assert cfg_train
+ assert cfg_train.data
+ assert cfg_train.model
+ assert cfg_train.trainer
+
+ HydraConfig().set_config(cfg_train)
+
+ hydra.utils.instantiate(cfg_train.data)
+ hydra.utils.instantiate(cfg_train.model)
+ hydra.utils.instantiate(cfg_train.trainer)
+
+
+def test_eval_config(cfg_eval: DictConfig) -> None:
+ """Tests the evaluation configuration provided by the `cfg_eval` pytest fixture.
+
+ :param cfg_eval: A DictConfig containing a valid evaluation configuration.
+ """
+ assert cfg_eval
+ assert cfg_eval.data
+ assert cfg_eval.model
+ assert cfg_eval.trainer
+
+ HydraConfig().set_config(cfg_eval)
+
+ hydra.utils.instantiate(cfg_eval.data)
+ hydra.utils.instantiate(cfg_eval.model)
+ hydra.utils.instantiate(cfg_eval.trainer)
diff --git a/tests/test_datamodules.py b/tests/test_datamodules.py
new file mode 100644
index 0000000000000000000000000000000000000000..bf909434415d50c5d181bf8cdc3261b024f52860
--- /dev/null
+++ b/tests/test_datamodules.py
@@ -0,0 +1,38 @@
+from pathlib import Path
+
+import pytest
+import torch
+
+from src.data.celeba_datamodule import MNISTDataModule
+
+
+@pytest.mark.parametrize("batch_size", [32, 128])
+def test_mnist_datamodule(batch_size: int) -> None:
+ """Tests `MNISTDataModule` to verify that it can be downloaded correctly, that the necessary
+ attributes were created (e.g., the dataloader objects), and that dtypes and batch sizes
+ correctly match.
+
+ :param batch_size: Batch size of the data to be loaded by the dataloader.
+ """
+ data_dir = "data/"
+
+ dm = MNISTDataModule(data_dir=data_dir, batch_size=batch_size)
+ dm.prepare_data()
+
+ assert not dm.data_train and not dm.data_val and not dm.data_test
+ assert Path(data_dir, "MNIST").exists()
+ assert Path(data_dir, "MNIST", "raw").exists()
+
+ dm.setup()
+ assert dm.data_train and dm.data_val and dm.data_test
+ assert dm.train_dataloader() and dm.val_dataloader() and dm.test_dataloader()
+
+ num_datapoints = len(dm.data_train) + len(dm.data_val) + len(dm.data_test)
+ assert num_datapoints == 70_000
+
+ batch = next(iter(dm.train_dataloader()))
+ x, y = batch
+ assert len(x) == batch_size
+ assert len(y) == batch_size
+ assert x.dtype == torch.float32
+ assert y.dtype == torch.int64
diff --git a/tests/test_eval.py b/tests/test_eval.py
new file mode 100644
index 0000000000000000000000000000000000000000..423c9d295047ba3c2a8e9306a1b975a09c34de09
--- /dev/null
+++ b/tests/test_eval.py
@@ -0,0 +1,39 @@
+import os
+from pathlib import Path
+
+import pytest
+from hydra.core.hydra_config import HydraConfig
+from omegaconf import DictConfig, open_dict
+
+from src.eval import evaluate
+from src.train import train
+
+
+@pytest.mark.slow
+def test_train_eval(tmp_path: Path, cfg_train: DictConfig, cfg_eval: DictConfig) -> None:
+ """Tests training and evaluation by training for 1 epoch with `train.py` then evaluating with
+ `eval.py`.
+
+ :param tmp_path: The temporary logging path.
+ :param cfg_train: A DictConfig containing a valid training configuration.
+ :param cfg_eval: A DictConfig containing a valid evaluation configuration.
+ """
+ assert str(tmp_path) == cfg_train.paths.output_dir == cfg_eval.paths.output_dir
+
+ with open_dict(cfg_train):
+ cfg_train.trainer.max_epochs = 1
+ cfg_train.test = True
+
+ HydraConfig().set_config(cfg_train)
+ train_metric_dict, _ = train(cfg_train)
+
+ assert "last.ckpt" in os.listdir(tmp_path / "checkpoints")
+
+ with open_dict(cfg_eval):
+ cfg_eval.ckpt_path = str(tmp_path / "checkpoints" / "last.ckpt")
+
+ HydraConfig().set_config(cfg_eval)
+ test_metric_dict, _ = evaluate(cfg_eval)
+
+ assert test_metric_dict["test/acc"] > 0.0
+ assert abs(train_metric_dict["test/acc"].item() - test_metric_dict["test/acc"].item()) < 0.001
diff --git a/tests/test_sweeps.py b/tests/test_sweeps.py
new file mode 100644
index 0000000000000000000000000000000000000000..7856b1551df4e3d4979110ede30076e6a703976f
--- /dev/null
+++ b/tests/test_sweeps.py
@@ -0,0 +1,107 @@
+from pathlib import Path
+
+import pytest
+
+from tests.helpers.run_if import RunIf
+from tests.helpers.run_sh_command import run_sh_command
+
+startfile = "src/train.py"
+overrides = ["logger=[]"]
+
+
+@RunIf(sh=True)
+@pytest.mark.slow
+def test_experiments(tmp_path: Path) -> None:
+ """Test running all available experiment configs with `fast_dev_run=True.`
+
+ :param tmp_path: The temporary logging path.
+ """
+ command = [
+ startfile,
+ "-m",
+ "experiment=glob(*)",
+ "hydra.sweep.dir=" + str(tmp_path),
+ "++trainer.fast_dev_run=true",
+ ] + overrides
+ run_sh_command(command)
+
+
+@RunIf(sh=True)
+@pytest.mark.slow
+def test_hydra_sweep(tmp_path: Path) -> None:
+ """Test default hydra sweep.
+
+ :param tmp_path: The temporary logging path.
+ """
+ command = [
+ startfile,
+ "-m",
+ "hydra.sweep.dir=" + str(tmp_path),
+ "model.optimizer.lr=0.005,0.01",
+ "++trainer.fast_dev_run=true",
+ ] + overrides
+
+ run_sh_command(command)
+
+
+@RunIf(sh=True)
+@pytest.mark.slow
+def test_hydra_sweep_ddp_sim(tmp_path: Path) -> None:
+ """Test default hydra sweep with ddp sim.
+
+ :param tmp_path: The temporary logging path.
+ """
+ command = [
+ startfile,
+ "-m",
+ "hydra.sweep.dir=" + str(tmp_path),
+ "trainer=ddp_sim",
+ "trainer.max_epochs=3",
+ "+trainer.limit_train_batches=0.01",
+ "+trainer.limit_val_batches=0.1",
+ "+trainer.limit_test_batches=0.1",
+ "model.optimizer.lr=0.005,0.01,0.02",
+ ] + overrides
+ run_sh_command(command)
+
+
+@RunIf(sh=True)
+@pytest.mark.slow
+def test_optuna_sweep(tmp_path: Path) -> None:
+ """Test Optuna hyperparam sweeping.
+
+ :param tmp_path: The temporary logging path.
+ """
+ command = [
+ startfile,
+ "-m",
+ "hparams_search=mnist_optuna",
+ "hydra.sweep.dir=" + str(tmp_path),
+ "hydra.sweeper.n_trials=10",
+ "hydra.sweeper.sampler.n_startup_trials=5",
+ "++trainer.fast_dev_run=true",
+ ] + overrides
+ run_sh_command(command)
+
+
+@RunIf(wandb=True, sh=True)
+@pytest.mark.slow
+def test_optuna_sweep_ddp_sim_wandb(tmp_path: Path) -> None:
+ """Test Optuna sweep with wandb logging and ddp sim.
+
+ :param tmp_path: The temporary logging path.
+ """
+ command = [
+ startfile,
+ "-m",
+ "hparams_search=mnist_optuna",
+ "hydra.sweep.dir=" + str(tmp_path),
+ "hydra.sweeper.n_trials=5",
+ "trainer=ddp_sim",
+ "trainer.max_epochs=3",
+ "+trainer.limit_train_batches=0.01",
+ "+trainer.limit_val_batches=0.1",
+ "+trainer.limit_test_batches=0.1",
+ "logger=wandb",
+ ]
+ run_sh_command(command)
diff --git a/tests/test_train.py b/tests/test_train.py
new file mode 100644
index 0000000000000000000000000000000000000000..c13ae02c8ae259553e0f0e8192cf054c228172dd
--- /dev/null
+++ b/tests/test_train.py
@@ -0,0 +1,108 @@
+import os
+from pathlib import Path
+
+import pytest
+from hydra.core.hydra_config import HydraConfig
+from omegaconf import DictConfig, open_dict
+
+from src.train import train
+from tests.helpers.run_if import RunIf
+
+
+def test_train_fast_dev_run(cfg_train: DictConfig) -> None:
+ """Run for 1 train, val and test step.
+
+ :param cfg_train: A DictConfig containing a valid training configuration.
+ """
+ HydraConfig().set_config(cfg_train)
+ with open_dict(cfg_train):
+ cfg_train.trainer.fast_dev_run = True
+ cfg_train.trainer.accelerator = "cpu"
+ train(cfg_train)
+
+
+@RunIf(min_gpus=1)
+def test_train_fast_dev_run_gpu(cfg_train: DictConfig) -> None:
+ """Run for 1 train, val and test step on GPU.
+
+ :param cfg_train: A DictConfig containing a valid training configuration.
+ """
+ HydraConfig().set_config(cfg_train)
+ with open_dict(cfg_train):
+ cfg_train.trainer.fast_dev_run = True
+ cfg_train.trainer.accelerator = "gpu"
+ train(cfg_train)
+
+
+@RunIf(min_gpus=1)
+@pytest.mark.slow
+def test_train_epoch_gpu_amp(cfg_train: DictConfig) -> None:
+ """Train 1 epoch on GPU with mixed-precision.
+
+ :param cfg_train: A DictConfig containing a valid training configuration.
+ """
+ HydraConfig().set_config(cfg_train)
+ with open_dict(cfg_train):
+ cfg_train.trainer.max_epochs = 1
+ cfg_train.trainer.accelerator = "gpu"
+ cfg_train.trainer.precision = 16
+ train(cfg_train)
+
+
+@pytest.mark.slow
+def test_train_epoch_double_val_loop(cfg_train: DictConfig) -> None:
+ """Train 1 epoch with validation loop twice per epoch.
+
+ :param cfg_train: A DictConfig containing a valid training configuration.
+ """
+ HydraConfig().set_config(cfg_train)
+ with open_dict(cfg_train):
+ cfg_train.trainer.max_epochs = 1
+ cfg_train.trainer.val_check_interval = 0.5
+ train(cfg_train)
+
+
+@pytest.mark.slow
+def test_train_ddp_sim(cfg_train: DictConfig) -> None:
+ """Simulate DDP (Distributed Data Parallel) on 2 CPU processes.
+
+ :param cfg_train: A DictConfig containing a valid training configuration.
+ """
+ HydraConfig().set_config(cfg_train)
+ with open_dict(cfg_train):
+ cfg_train.trainer.max_epochs = 2
+ cfg_train.trainer.accelerator = "cpu"
+ cfg_train.trainer.devices = 2
+ cfg_train.trainer.strategy = "ddp_spawn"
+ train(cfg_train)
+
+
+@pytest.mark.slow
+def test_train_resume(tmp_path: Path, cfg_train: DictConfig) -> None:
+ """Run 1 epoch, finish, and resume for another epoch.
+
+ :param tmp_path: The temporary logging path.
+ :param cfg_train: A DictConfig containing a valid training configuration.
+ """
+ with open_dict(cfg_train):
+ cfg_train.trainer.max_epochs = 1
+
+ HydraConfig().set_config(cfg_train)
+ metric_dict_1, _ = train(cfg_train)
+
+ files = os.listdir(tmp_path / "checkpoints")
+ assert "last.ckpt" in files
+ assert "epoch_000.ckpt" in files
+
+ with open_dict(cfg_train):
+ cfg_train.ckpt_path = str(tmp_path / "checkpoints" / "last.ckpt")
+ cfg_train.trainer.max_epochs = 2
+
+ metric_dict_2, _ = train(cfg_train)
+
+ files = os.listdir(tmp_path / "checkpoints")
+ assert "epoch_001.ckpt" in files
+ assert "epoch_002.ckpt" not in files
+
+ assert metric_dict_1["train/acc"] < metric_dict_2["train/acc"]
+ assert metric_dict_1["val/acc"] < metric_dict_2["val/acc"]
diff --git a/wandb_vis.py b/wandb_vis.py
new file mode 100644
index 0000000000000000000000000000000000000000..97424a0fa466a8585b79be8543b8b2cacbd20364
--- /dev/null
+++ b/wandb_vis.py
@@ -0,0 +1,181 @@
+# -*- coding: utf-8 -*-
+# @Time : 2024/8/3 10:46 AM
+# @Author : xiaoshun
+# @Email : 3038523973@qq.com
+# @File : wandb_vis.py
+# @Software: PyCharm
+import argparse
+import os
+import shutil
+from glob import glob
+
+import albumentations as albu
+import numpy as np
+import torch
+import wandb
+from PIL import Image
+from albumentations.pytorch.transforms import ToTensorV2
+from matplotlib import pyplot as plt
+from rich.progress import track
+from thop import profile
+
+from src.data.components.hrcwhu import HRCWHU
+from src.data.hrcwhu_datamodule import HRCWHUDataModule
+from src.models.components.cdnetv1 import CDnetV1
+from src.models.components.cdnetv2 import CDnetV2
+from src.models.components.dbnet import DBNet
+from src.models.components.hrcloudnet import HRCloudNet
+from src.models.components.mcdnet import MCDNet
+from src.models.components.scnn import SCNN
+
+
+class WandbVis:
+ def __init__(self, model_name):
+ self.model_name = model_name
+ self.device = "cuda:1" if torch.cuda.is_available() else "cpu"
+ # self.device = "cpu"
+ self.colors = ((255, 255, 255), (128, 192, 128))
+ self.num_classes = 2
+ self.model = self.load_model()
+ self.dataloader = self.load_dataset()
+ self.macs, self.params = None, None
+ wandb.init(project='model_vis', name=self.model_name)
+
+ def load_weight(self, weight_path: str):
+ weight = torch.load(weight_path, map_location=self.device)
+ state_dict = {}
+ for key, value in weight["state_dict"].items():
+ new_key = key[4:]  # strip the first 4 characters (assumed to be the "net." prefix from the LightningModule)
+ state_dict[new_key] = value
+ return state_dict
+
+ def load_model_by_model_name(self):
+ if self.model_name == "dbnet":
+ return DBNet(img_size=256, in_channels=3, num_classes=2).to(self.device)
+ if self.model_name == "cdnetv1":
+ return CDnetV1(num_classes=2).to(self.device)
+ if self.model_name == "cdnetv2":
+ return CDnetV2(num_classes=2).to(self.device)
+ if self.model_name == "hrcloud":
+ return HRCloudNet(num_classes=2).to(self.device)
+ if self.model_name == "mcdnet":
+ return MCDNet(in_channels=3, num_classes=2).to(self.device)
+ if self.model_name == "scnn":
+ return SCNN(num_classes=2).to(self.device)
+
+ raise ValueError(f"{self.model_name}模型不存在")
+
+ def load_model(self):
+ weight_path = glob(f"logs/train/runs/hrcwhu_{self.model_name}/*/checkpoints/*.ckpt")[0]
+ model = self.load_model_by_model_name()
+ state_dict = self.load_weight(weight_path)
+ model.load_state_dict(state_dict)
+ model.eval()
+ return model
+
+ def load_dataset(self):
+
+ all_transform = albu.Compose(
+ [
+ albu.Resize(
+ height=HRCWHU.METAINFO["img_size"][1],
+ width=HRCWHU.METAINFO["img_size"][2],
+ always_apply=True
+ )
+ ]
+ )
+ img_transform = albu.Compose([
+ albu.ToFloat(255.0),
+ ToTensorV2()
+ ])
+ ann_transform = None
+ val_pipeline = dict(
+ all_transform=all_transform,
+ img_transform=img_transform,
+ ann_transform=ann_transform,
+ )
+ dataloader = HRCWHUDataModule(
+ root="/home/liujie/liumin/cloudseg/data/hrcwhu",
+ train_pipeline=val_pipeline,
+ val_pipeline=val_pipeline,
+ test_pipeline=val_pipeline,
+ batch_size=1,
+ )
+ dataloader.setup()
+ test_dataloader = dataloader.test_dataloader()
+ return test_dataloader
+ # each batch from the test dataloader: img of shape [1, 3, 256, 256], ann of shape [1, 256, 256],
+ # plus img_path, ann_path and lac_type metadata
+
+ def give_colors_to_mask(self, mask: np.ndarray):
+ """
+ 赋予mask颜色
+ """
+ assert len(mask.shape) == 2, "Value Error,mask的形状为(height,width)"
+ colors_mask = np.zeros((mask.shape[0], mask.shape[1], 3)).astype(np.float32)
+ for color in range(2):
+ segc = (mask == color)
+ colors_mask[:, :, 0] += segc * (self.colors[color][0])
+ colors_mask[:, :, 1] += segc * (self.colors[color][1])
+ colors_mask[:, :, 2] += segc * (self.colors[color][2])
+ return colors_mask
+
+ @torch.no_grad()
+ def pred_mask(self, x: torch.Tensor):
+ x = x.to(self.device)
+ self.macs, self.params = profile(self.model, inputs=(x,), verbose=False)
+ logits = self.model(x)
+ if isinstance(logits, tuple):
+ logits = logits[0]
+ fake_mask = torch.argmax(logits, 1).detach().cpu().squeeze(0).numpy()
+ return fake_mask
+
+ def np_pil_np(self, image: np.ndarray, filename="colors_ann"):
+ colors_np = self.give_colors_to_mask(image)
+ pil_np = Image.fromarray(np.uint8(colors_np))
+ return np.array(pil_np)
+
+ def run(self, delete_wandb_log=True):
+ # the test dataloader yields 30 batches (batch_size=1)
+ for data in track(self.dataloader):
+ img = data["img"]
+ ann = data["ann"].squeeze(0).numpy()
+ img_path = data["img_path"]
+ fake_mask = self.pred_mask(img)
+
+ colors_ann = self.np_pil_np(ann)
+ colors_fake = self.np_pil_np(fake_mask, "colors_fake")
+ image_name = img_path[0].split(os.path.sep)[-1].split(".")[0]
+
+ plt.subplot(1, 2, 1)
+ plt.axis("off")
+ plt.title("groud true")
+ plt.imshow(colors_ann)
+
+ plt.subplot(1, 2, 2)
+ plt.axis("off")
+ plt.title("predict mask")
+ plt.imshow(colors_fake)
+ wandb.log({image_name: wandb.Image(plt)})
+ wandb.log({"macs":self.macs,"params":self.params})
+ wandb.finish()
+ if delete_wadb_log and os.path.exists("wandb"):
+ shutil.rmtree("wandb")
+
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser()
+ parser.add_argument("--model-name", type=str, default="dbnet")
+ parser.add_argument("--delete-wadb-log", type=bool, default=True)
+ args = parser.parse_args()
+ vis = WandbVis(model_name=args.model_name)
+ vis.run(delete_wandb_log=args.delete_wandb_log)