Commit · 922c280
1 Parent(s): b03b420

added dataset scripts
- Makefile +0 -142
- command.py +40 -0
- config.yaml +13 -0
- requirements.txt +3 -14
- src/__init__.py +9 -0
- src/data/load_dataset.py +73 -0
- src/data/make_dataset.py +120 -22
- src/data/visualize_dataset.py +52 -0
- src/utils.py +39 -0
Makefile
CHANGED
@@ -1,144 +1,2 @@
 .PHONY: clean data lint requirements sync_data_to_s3 sync_data_from_s3
 
-#################################################################################
-# GLOBALS                                                                       #
-#################################################################################
-
-PROJECT_DIR := $(shell dirname $(realpath $(lastword $(MAKEFILE_LIST))))
-BUCKET = [OPTIONAL] your-bucket-for-syncing-data (do not include 's3://')
-PROFILE = default
-PROJECT_NAME = project_name
-PYTHON_INTERPRETER = python
-
-ifeq (,$(shell which conda))
-HAS_CONDA=False
-else
-HAS_CONDA=True
-endif
-
-#################################################################################
-# COMMANDS                                                                      #
-#################################################################################
-
-## Install Python Dependencies
-requirements: test_environment
-	$(PYTHON_INTERPRETER) -m pip install -U pip setuptools wheel
-	$(PYTHON_INTERPRETER) -m pip install -r requirements.txt
-
-## Make Dataset
-data: requirements
-	$(PYTHON_INTERPRETER) src/data/make_dataset.py data/raw data/processed
-
-## Delete all compiled Python files
-clean:
-	find . -type f -name "*.py[co]" -delete
-	find . -type d -name "__pycache__" -delete
-
-## Lint using flake8
-lint:
-	flake8 src
-
-## Upload Data to S3
-sync_data_to_s3:
-ifeq (default,$(PROFILE))
-	aws s3 sync data/ s3://$(BUCKET)/data/
-else
-	aws s3 sync data/ s3://$(BUCKET)/data/ --profile $(PROFILE)
-endif
-
-## Download Data from S3
-sync_data_from_s3:
-ifeq (default,$(PROFILE))
-	aws s3 sync s3://$(BUCKET)/data/ data/
-else
-	aws s3 sync s3://$(BUCKET)/data/ data/ --profile $(PROFILE)
-endif
-
-## Set up python interpreter environment
-create_environment:
-ifeq (True,$(HAS_CONDA))
-	@echo ">>> Detected conda, creating conda environment."
-ifeq (3,$(findstring 3,$(PYTHON_INTERPRETER)))
-	conda create --name $(PROJECT_NAME) python=3
-else
-	conda create --name $(PROJECT_NAME) python=2.7
-endif
-	@echo ">>> New conda env created. Activate with:\nsource activate $(PROJECT_NAME)"
-else
-	$(PYTHON_INTERPRETER) -m pip install -q virtualenv virtualenvwrapper
-	@echo ">>> Installing virtualenvwrapper if not already installed.\nMake sure the following lines are in shell startup file\n\
-	export WORKON_HOME=$$HOME/.virtualenvs\nexport PROJECT_HOME=$$HOME/Devel\nsource /usr/local/bin/virtualenvwrapper.sh\n"
-	@bash -c "source `which virtualenvwrapper.sh`;mkvirtualenv $(PROJECT_NAME) --python=$(PYTHON_INTERPRETER)"
-	@echo ">>> New virtualenv created. Activate with:\nworkon $(PROJECT_NAME)"
-endif
-
-## Test python environment is setup correctly
-test_environment:
-	$(PYTHON_INTERPRETER) test_environment.py
-
-#################################################################################
-# PROJECT RULES                                                                 #
-#################################################################################
-
-
-
-#################################################################################
-# Self Documenting Commands                                                     #
-#################################################################################
-
-.DEFAULT_GOAL := help
-
-# Inspired by <http://marmelab.com/blog/2016/02/29/auto-documented-makefile.html>
-# sed script explained:
-# /^##/:
-# * save line in hold space
-# * purge line
-# * Loop:
-# * append newline + line to hold space
-# * go to next line
-# * if line starts with doc comment, strip comment character off and loop
-# * remove target prerequisites
-# * append hold space (+ newline) to line
-# * replace newline plus comments by `---`
-# * print line
-# Separate expressions are necessary because labels cannot be delimited by
-# semicolon; see <http://stackoverflow.com/a/11799865/1968>
-.PHONY: help
-help:
-	@echo "$$(tput bold)Available rules:$$(tput sgr0)"
-	@echo
-	@sed -n -e "/^## / { \
-		h; \
-		s/.*//; \
-		:doc" \
-		-e "H; \
-		n; \
-		s/^## //; \
-		t doc" \
-		-e "s/:.*//; \
-		G; \
-		s/\\n## /---/; \
-		s/\\n/ /g; \
-		p; \
-	}" ${MAKEFILE_LIST} \
-	| LC_ALL='C' sort --ignore-case \
-	| awk -F '---' \
-		-v ncol=$$(tput cols) \
-		-v indent=19 \
-		-v col_on="$$(tput setaf 6)" \
-		-v col_off="$$(tput sgr0)" \
-	'{ \
-		printf "%s%*s%s ", col_on, -indent, $$1, col_off; \
-		n = split($$2, words, " "); \
-		line_length = ncol - indent; \
-		for (i = 1; i <= n; i++) { \
-			line_length -= length(words[i]) + 1; \
-			if (line_length <= 0) { \
-				line_length = ncol - indent - length(words[i]) - 1; \
-				printf "\n%*s ", -indent, " "; \
-			} \
-			printf "%s ", words[i]; \
-		} \
-		printf "\n"; \
-	}' \
-	| more $(shell test $(shell uname) = Darwin && echo '--no-init --raw-control-chars')
command.py
ADDED
@@ -0,0 +1,40 @@
+import argparse
+import sys
+import os
+
+# parser = argparse.ArgumentParser()
+# parser.add_argument("category")
+# parser.add_argument("subcommand-args")
+# args = parser.parse_args()
+args = sys.argv
+
+# remove "command.py"
+args = args[1:]
+
+# print(args)
+subcommand = args[0].lower()
+
+subcommand_args = " ".join(args[1:])
+if subcommand=="data":
+    command = "py src/data/make_dataset.py "+subcommand_args
+    # print(command)
+    os.system(command)
+else:
+    print("subcommand not supported.")
+
+# os.system("py src/__init__.py")
+"""
+download the dataset: data download
+preprocess dataset: data prepare
+visualize dataset: data show
+delete raw & interim dataset dir: data delete --cache
+delete all dataset dir: data delete --all
+
+
+train model: model train
+evaluate model: model evaluate
+inference with model: model predict --image test.jpg --folder images/ -d results/
+
+
+
+"""
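
command.py builds a shell string and runs it with os.system through the Windows `py` launcher. A more portable sketch of the same dispatch, using the current interpreter and an argument list instead of string concatenation; this is an alternative, not what the commit ships:

# Hypothetical cross-platform variant of the dispatcher above.
import subprocess
import sys

def dispatch(argv):
    if not argv:
        print("subcommand not supported.")
        return
    # first token selects the category, the rest is forwarded verbatim
    subcommand, rest = argv[0].lower(), argv[1:]
    if subcommand == "data":
        subprocess.run([sys.executable, "src/data/make_dataset.py", *rest], check=True)
    else:
        print("subcommand not supported.")

if __name__ == "__main__":
    dispatch(sys.argv[1:])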
config.yaml
ADDED
@@ -0,0 +1,13 @@
+raw_dataset_dir: data/raw/
+interim_dataset_dir: data/interim/
+processed_dataset_dir: data/processed/
+
+# forests or pascal-voc
+dataset: forests
+
+image_size: 224
+train_size: 0.8
+shuffle: False
+batch_size: 16
+
+seed: 324
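
These keys are consumed through the Config wrapper added in src/utils.py and exposed as a module-level `config` object in src/__init__.py. A small sketch of the lookup behaviour (run from the repository root so config.yaml resolves; `missing_key` is a made-up name):

# Attribute access falls through Config.__getattr__ to dict.get on the parsed YAML.
from src import config

print(config.dataset)      # forests
print(config.image_size)   # 224
print(config.batch_size)   # 16
print(config.missing_key)  # None -- absent keys come back as None, not AttributeError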
requirements.txt
CHANGED
@@ -1,14 +1,3 @@
-
-
-
-# external requirements
-click
-Sphinx
-coverage
-awscli
-flake8
-python-dotenv>=0.5.1
-
-
-# backwards compatibility
-pathlib2
+huggingface_hub
+comet_ml
+scikit-image
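
The new scripts also import tensorflow, cv2, numpy, matplotlib, tqdm, scikit-learn, and pyyaml, which are not pinned here, so presumably they are expected from the Space's base image. A quick, hypothetical pre-flight check for the runtime:

# Report whether each runtime dependency of the new src/ scripts is importable.
import importlib.util

deps = ["huggingface_hub", "skimage", "tensorflow", "cv2", "numpy",
        "matplotlib", "tqdm", "sklearn", "yaml"]
for mod in deps:
    status = "ok" if importlib.util.find_spec(mod) is not None else "MISSING"
    print(f"{mod}: {status}")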
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
src/__init__.py
CHANGED
@@ -0,0 +1,9 @@
+from src.utils import Config
+from pathlib import Path
+
+config = Config("config.yaml")
+# config.raw_dataset_dir = Path(config.raw_dataset_dir)
+# config.interim_dataset_dir = Path(config.interim_dataset_dir)
+# config.processed_dataset_dir = Path(config.processed_dataset_dir)
+
+# print(config)
src/data/load_dataset.py
ADDED
@@ -0,0 +1,73 @@
+import os,sys;sys.path.append(os.getcwd())
+import tensorflow as tf
+from src import config
+from src.utils import *
+from pathlib import Path
+from glob import glob
+import sklearn.model_selection
+from skimage.color import rgb2lab, lab2rgb
+
+def get_datasets():
+    trainval_dir = Path(config.processed_dataset_dir) / Path("trainval/")
+    test_dir = Path(config.processed_dataset_dir) / Path("test/")
+
+    trainval_paths = glob(str(trainval_dir/Path("*")))
+    test_paths = glob(str(test_dir/Path("*")))
+
+    len(trainval_paths),len(test_paths)
+
+
+
+    train_paths,val_paths = sklearn.model_selection.train_test_split(trainval_paths,
+                                                                     train_size=0.8,
+                                                                     random_state=324)
+
+    print("train|val split:",len(train_paths),"|",len(val_paths))
+
+    train_ds = get_ds(train_paths,bs=config.batch_size,shuffle=config.shuffle)
+    val_ds = get_ds(val_paths,bs=config.batch_size,shuffle=False,is_val=True)
+    test_ds = get_ds(test_paths,bs=config.batch_size,shuffle=False,is_val=True)
+
+    return train_ds,val_ds,test_ds
+
+
+# def test_dataset():
+#     train_ds = get_ds(train_paths,shuffle=False)
+#     L_batch,AB_batch = next(iter(train_ds))
+#     L_batch = L_batch.numpy()
+#     AB_batch = AB_batch.numpy()
+#     print("L:",L_batch.min(),L_batch.max())
+#     print("A:",AB_batch[:,:,:,0].min(),AB_batch[:,:,:,0].max())
+#     print("B:",AB_batch[:,:,:,1].min(),AB_batch[:,:,:,1].max())
+
+
+
+def tf_RGB_TO_LAB(image):
+    def f(image):
+        image = rgb2lab(image)
+        return image
+    lab = tf.numpy_function(f,[image],tf.float32)
+    lab.set_shape(image.shape)
+    return lab
+
+
+# load the image in lab space and split the l and ab channels
+def load_img(img_path):
+    img_bytes = tf.io.read_file(img_path)
+    image = tf.image.decode_image(img_bytes,3,expand_animations=False)
+    image = tf.image.resize(image,[config.image_size,config.image_size])
+    image = image / 255.0
+    image = tf_RGB_TO_LAB(image)
+
+    L,AB = image[:,:,0:1],image[:,:,1:]
+    L,AB = scale_L(L),scale_AB(AB)
+    return L,AB
+
+def get_ds(image_paths,bs=8,shuffle=False,is_val=False):
+    ds = tf.data.Dataset.from_tensor_slices(image_paths)
+    if shuffle: ds = ds.shuffle(len(image_paths))
+    ds = ds.map(load_img,num_parallel_calls=tf.data.AUTOTUNE)
+    ds = ds.batch(bs,num_parallel_calls=tf.data.AUTOTUNE,drop_remainder=not is_val)
+
+    return ds
+
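
A usage sketch for the loader, mirroring the commented-out test_dataset helper above: build the three splits and inspect one training batch. It assumes `make_dataset.py prepare` has already populated data/processed/ and that it runs from the repository root:

# scale_L divides L by 100 and scale_AB divides AB by 128, so the batch
# should land roughly in [0, 1] for L and [-1, 1] for AB.
from src.data.load_dataset import get_datasets

train_ds, val_ds, test_ds = get_datasets()

L_batch, AB_batch = next(iter(train_ds))
L_batch, AB_batch = L_batch.numpy(), AB_batch.numpy()
print("L: ", L_batch.shape, L_batch.min(), L_batch.max())
print("AB:", AB_batch.shape, AB_batch.min(), AB_batch.max())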
src/data/make_dataset.py
CHANGED
@@ -1,30 +1,128 @@
-
-import
-import
-from
-
-
-@click.argument('input_filepath', type=click.Path(exists=True))
-@click.argument('output_filepath', type=click.Path())
-def main(input_filepath, output_filepath):
-    """ Runs data processing scripts to turn raw data from (../raw) into
-        cleaned data ready to be analyzed (saved in ../processed).
-
-
-if
-
-    logging.basicConfig(level=logging.INFO, format=log_fmt)
-
-
-
-
-
-    main()
+from huggingface_hub import snapshot_download
+import os,sys;sys.path.append(os.getcwd())
+from src import config
+from src.utils import *
+import argparse
+from pathlib import Path
+from zipfile import ZipFile
+from glob import glob
+import cv2
+import numpy as np
+import matplotlib.pyplot as plt
+from tqdm import tqdm
+import shutil
+from src.data.visualize_dataset import visualize_dataset
+
+def download_dataset():
+    """Used to download dataset from hugging face
+    """
+    print_title(f"Downloading {config.dataset} dataset from hugging face")
+    snapshot_download(repo_id="Anuj-Panthri/Image-Colorization-Datasets",
+                      repo_type="dataset",
+                      local_dir=config.raw_dataset_dir,
+                      allow_patterns=f"{config.dataset}/*")
+
+
+def unzip_dataset():
+    print_title(f"Unzipping dataset")
+    print("Extracting to :",Path(config.interim_dataset_dir)/Path("trainval/"))
+    with ZipFile(Path(config.raw_dataset_dir)/Path(f"{config.dataset}/trainval.zip"),"r") as zip:
+        zip.extractall(Path(config.interim_dataset_dir)/Path("trainval/"))
+
+    print("Extracting to :",Path(config.interim_dataset_dir)/Path("test/"))
+    with ZipFile(Path(config.raw_dataset_dir)/Path(f"{config.dataset}/test.zip"),"r") as zip:
+        zip.extractall(Path(config.interim_dataset_dir)/Path("test/"))
+
+
+def clean_dataset():
+    print_title("CLEANING DATASET")
+    trainval_dir = Path(config.interim_dataset_dir) / Path("trainval/")
+    test_dir = Path(config.interim_dataset_dir) / Path("test/")
+
+    trainval_paths = glob(str(trainval_dir/Path("*")))
+    test_paths = glob(str(test_dir/Path("*")))
+
+    print("train,test: ",len(trainval_paths),",",len(test_paths),sep="")
+
+
+    def clean(image_paths,destination_dir):
+        if os.path.exists(destination_dir): shutil.rmtree(destination_dir)
+        os.makedirs(destination_dir)
+        for i in tqdm(range(len(image_paths))):
+            img = cv2.imread(image_paths[i])
+            img = cv2.resize(img,[128,128])
+            if not is_bw(img):
+                # copy the non-grayscale image from the split currently being cleaned
+                shutil.copy(image_paths[i],
+                            destination_dir)
+        print("saved to:",destination_dir)
+
+    destination_dir = Path(config.processed_dataset_dir)/Path("trainval/")
+    clean(trainval_paths,destination_dir)
+
+    destination_dir = Path(config.processed_dataset_dir)/Path("test/")
+    clean(test_paths,destination_dir)
+
+    trainval_dir = Path(config.processed_dataset_dir) / Path("trainval/")
+    test_dir = Path(config.processed_dataset_dir) / Path("test/")
+
+    trainval_paths = glob(str(trainval_dir/Path("*")))
+    test_paths = glob(str(test_dir/Path("*")))
+
+    print("after cleaning train,test: ",len(trainval_paths),",",len(test_paths),sep="")
+
+
+def prepare_dataset():
+    print_title(f"Preparing dataset")
+    download_dataset()
+    unzip_dataset()
+    clean_dataset()
+
+def delete_cache():
+    ## clean old interim and raw datasets
+    print_title("deleting unused raw and interim dataset dirs")
+    if os.path.exists(config.raw_dataset_dir):
+        shutil.rmtree(config.raw_dataset_dir)
+    if os.path.exists(config.interim_dataset_dir):
+        shutil.rmtree(config.interim_dataset_dir)
+
+def delete_all():
+    ## clean all datasets
+    print_title("deleting all dataset dirs")
+    if os.path.exists(config.raw_dataset_dir):
+        shutil.rmtree(config.raw_dataset_dir)
+    if os.path.exists(config.interim_dataset_dir):
+        shutil.rmtree(config.interim_dataset_dir)
+    if os.path.exists(config.processed_dataset_dir):
+        shutil.rmtree(config.processed_dataset_dir)
+
+
+if __name__=="__main__":
+    parser = argparse.ArgumentParser()
+    parser.add_argument("command")
+    parser.add_argument("-d","--dataset",default="forests")
+    parser.add_argument("--cache",action="store_true",default=True)
+    parser.add_argument("--all",action="store_true")
+
+    """
+    prepare dataset: data prepare
+    visualize dataset: data show
+    delete raw & interim dataset dir: data delete --cache
+    delete all dataset dir: data delete --all
+    """
+
+    args = parser.parse_args()
+    # print(args)
+
+    if args.command=="prepare":
+        prepare_dataset()
+
+    elif args.command=="show":
+        visualize_dataset()
+
+    elif args.command=="delete":
+        if(args.all): delete_all()
+        elif(args.cache): delete_cache()
+
+    else:
+        print("unsupported")
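
The supported invocations follow the argparse block above and the usage notes in command.py; `command.py data <args>` forwards its arguments to this script. A sketch, assuming it runs from the repository root (the `prepare` call is what takes over from the removed `make data` rule):

# Drive the dataset pipeline directly with the current interpreter.
import subprocess
import sys

py = sys.executable
subprocess.run([py, "src/data/make_dataset.py", "prepare"], check=True)            # download + unzip + clean
subprocess.run([py, "src/data/make_dataset.py", "show"], check=True)               # plot one batch per split
subprocess.run([py, "src/data/make_dataset.py", "delete", "--cache"], check=True)  # drop data/raw and data/interim
# "delete --all" also removes data/processed; because --cache defaults to True,
# a bare "delete" behaves like "delete --cache".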
src/data/visualize_dataset.py
ADDED
@@ -0,0 +1,52 @@
+import os,sys;sys.path.append(os.getcwd())
+from src.data.load_dataset import get_ds,get_datasets
+from src import config
+from src.utils import *
+import matplotlib.pyplot as plt
+import cv2
+import math
+
+def see_batch(L_batch,AB_batch,show_L=False,cols=4,row_size=5,col_size=5,title=None):
+    n = L_batch.shape[0]
+    rows = math.ceil(n/cols)
+    fig = plt.figure(figsize=(col_size*cols,row_size*rows))
+    if title:
+        plt.title(title)
+        plt.axis("off")
+
+    for i in range(n):
+        fig.add_subplot(rows,cols,i+1)
+        L,AB = L_batch[i],AB_batch[i]
+        L,AB = rescale_L(L), rescale_AB(AB)
+        # print(L.shape,AB.shape)
+        img = np.concatenate([L,AB],axis=-1)
+        img = cv2.cvtColor(img,cv2.COLOR_LAB2RGB)*255
+        # print(img.min(),img.max())
+        if show_L:
+            L = np.tile(L,(1,1,3))/100*255
+            img = np.concatenate([L,img],axis=1)
+        plt.imshow(img.astype("uint8"))
+    plt.show()
+
+
+def visualize_dataset():
+    train_ds,val_ds,test_ds = get_datasets()
+    L_batch,AB_batch = next(iter(train_ds))
+    L_batch,AB_batch = L_batch.numpy(), AB_batch.numpy()
+    see_batch(L_batch,
+              AB_batch,
+              title="training dataset")
+
+    L_batch,AB_batch = next(iter(val_ds))
+    L_batch,AB_batch = L_batch.numpy(), AB_batch.numpy()
+    see_batch(L_batch,
+              AB_batch,
+              title="validation dataset")
+
+    L_batch,AB_batch = next(iter(test_ds))
+    L_batch,AB_batch = L_batch.numpy(), AB_batch.numpy()
+    see_batch(L_batch,
+              AB_batch,
+              title="testing dataset")
+
+
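
see_batch expects scaled batches: L shaped (N, H, W, 1) in roughly [0, 1] and AB shaped (N, H, W, 2) in roughly [-1, 1]. A synthetic sketch that exercises the plotting path without any dataset on disk (the random values are placeholders, not real images):

# Build a fake scaled batch and render it; show_L also tiles the L channel alongside.
import numpy as np
from src.data.visualize_dataset import see_batch

rng = np.random.default_rng(0)
L_batch = rng.uniform(0.0, 1.0, size=(8, 224, 224, 1)).astype("float32")
AB_batch = rng.uniform(-1.0, 1.0, size=(8, 224, 224, 2)).astype("float32")
see_batch(L_batch, AB_batch, show_L=True, title="synthetic batch")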
src/utils.py
ADDED
@@ -0,0 +1,39 @@
+import yaml
+import numpy as np
+
+class Config:
+    def __init__(self,path="config.yaml"):
+        with open(path,'r') as f:
+            self.config = yaml.safe_load(f)
+
+    def __str__(self):
+        return str(self.config)
+
+    def __getattr__(self, name: str):
+        return self.config.get(name)
+
+    # def __setattr__(self, name: str, value: any):
+    #     self.config[name]=value
+
+def is_bw(img):
+    rg,gb,rb = img[:,:,0]-img[:,:,1] , img[:,:,1]-img[:,:,2] , img[:,:,0]-img[:,:,2]
+    rg,gb,rb = np.abs(rg).sum(),np.abs(gb).sum(),np.abs(rb).sum()
+    avg = np.mean([rg,gb,rb])
+    # print(rg,gb,rb)
+
+    return avg<10
+
+def print_title(msg:str,n=30):
+    print("="*n,msg.upper(),"="*n,sep="")
+
+def scale_L(L):
+    return L/100
+def rescale_L(L):
+    return L*100
+
+def scale_AB(AB):
+    return AB/128
+
+def rescale_AB(AB):
+    return AB*128
+
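
is_bw treats an image as grayscale when the summed absolute differences between its colour channels average below 10, i.e. the three channels are essentially identical; make_dataset.py uses it to drop black-and-white images during cleaning. A tiny check with synthetic arrays:

# Identical channels pass the grayscale test, a tinted copy does not.
import numpy as np
from src.utils import is_bw

gray = np.dstack([np.full((128, 128), 90, dtype=np.uint8)] * 3)  # R == G == B everywhere
color = gray.copy()
color[:, :, 2] = 120                                             # perturb one channel

print(is_bw(gray))   # True  -- channel differences sum to zero
print(is_bw(color))  # False -- differences are far above the threshold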