Anuj-Panthri committed
Commit 922c280 · 1 Parent(s): b03b420

added dataset scripts

Makefile CHANGED
@@ -1,144 +1,2 @@
  .PHONY: clean data lint requirements sync_data_to_s3 sync_data_from_s3
 
- #################################################################################
- # GLOBALS #
- #################################################################################
-
- PROJECT_DIR := $(shell dirname $(realpath $(lastword $(MAKEFILE_LIST))))
- BUCKET = [OPTIONAL] your-bucket-for-syncing-data (do not include 's3://')
- PROFILE = default
- PROJECT_NAME = project_name
- PYTHON_INTERPRETER = python
-
- ifeq (,$(shell which conda))
- HAS_CONDA=False
- else
- HAS_CONDA=True
- endif
-
- #################################################################################
- # COMMANDS #
- #################################################################################
-
- ## Install Python Dependencies
- requirements: test_environment
- $(PYTHON_INTERPRETER) -m pip install -U pip setuptools wheel
- $(PYTHON_INTERPRETER) -m pip install -r requirements.txt
-
- ## Make Dataset
- data: requirements
- $(PYTHON_INTERPRETER) src/data/make_dataset.py data/raw data/processed
-
- ## Delete all compiled Python files
- clean:
- find . -type f -name "*.py[co]" -delete
- find . -type d -name "__pycache__" -delete
-
- ## Lint using flake8
- lint:
- flake8 src
-
- ## Upload Data to S3
- sync_data_to_s3:
- ifeq (default,$(PROFILE))
- aws s3 sync data/ s3://$(BUCKET)/data/
- else
- aws s3 sync data/ s3://$(BUCKET)/data/ --profile $(PROFILE)
- endif
-
- ## Download Data from S3
- sync_data_from_s3:
- ifeq (default,$(PROFILE))
- aws s3 sync s3://$(BUCKET)/data/ data/
- else
- aws s3 sync s3://$(BUCKET)/data/ data/ --profile $(PROFILE)
- endif
-
- ## Set up python interpreter environment
- create_environment:
- ifeq (True,$(HAS_CONDA))
- @echo ">>> Detected conda, creating conda environment."
- ifeq (3,$(findstring 3,$(PYTHON_INTERPRETER)))
- conda create --name $(PROJECT_NAME) python=3
- else
- conda create --name $(PROJECT_NAME) python=2.7
- endif
- @echo ">>> New conda env created. Activate with:\nsource activate $(PROJECT_NAME)"
- else
- $(PYTHON_INTERPRETER) -m pip install -q virtualenv virtualenvwrapper
- @echo ">>> Installing virtualenvwrapper if not already installed.\nMake sure the following lines are in shell startup file\n\
- export WORKON_HOME=$$HOME/.virtualenvs\nexport PROJECT_HOME=$$HOME/Devel\nsource /usr/local/bin/virtualenvwrapper.sh\n"
- @bash -c "source `which virtualenvwrapper.sh`;mkvirtualenv $(PROJECT_NAME) --python=$(PYTHON_INTERPRETER)"
- @echo ">>> New virtualenv created. Activate with:\nworkon $(PROJECT_NAME)"
- endif
-
- ## Test python environment is setup correctly
- test_environment:
- $(PYTHON_INTERPRETER) test_environment.py
-
- #################################################################################
- # PROJECT RULES #
- #################################################################################
-
-
-
- #################################################################################
- # Self Documenting Commands #
- #################################################################################
-
- .DEFAULT_GOAL := help
-
- # Inspired by <http://marmelab.com/blog/2016/02/29/auto-documented-makefile.html>
- # sed script explained:
- # /^##/:
- # * save line in hold space
- # * purge line
- # * Loop:
- # * append newline + line to hold space
- # * go to next line
- # * if line starts with doc comment, strip comment character off and loop
- # * remove target prerequisites
- # * append hold space (+ newline) to line
- # * replace newline plus comments by `---`
- # * print line
- # Separate expressions are necessary because labels cannot be delimited by
- # semicolon; see <http://stackoverflow.com/a/11799865/1968>
- .PHONY: help
- help:
- @echo "$$(tput bold)Available rules:$$(tput sgr0)"
- @echo
- @sed -n -e "/^## / { \
- h; \
- s/.*//; \
- :doc" \
- -e "H; \
- n; \
- s/^## //; \
- t doc" \
- -e "s/:.*//; \
- G; \
- s/\\n## /---/; \
- s/\\n/ /g; \
- p; \
- }" ${MAKEFILE_LIST} \
- | LC_ALL='C' sort --ignore-case \
- | awk -F '---' \
- -v ncol=$$(tput cols) \
- -v indent=19 \
- -v col_on="$$(tput setaf 6)" \
- -v col_off="$$(tput sgr0)" \
- '{ \
- printf "%s%*s%s ", col_on, -indent, $$1, col_off; \
- n = split($$2, words, " "); \
- line_length = ncol - indent; \
- for (i = 1; i <= n; i++) { \
- line_length -= length(words[i]) + 1; \
- if (line_length <= 0) { \
- line_length = ncol - indent - length(words[i]) - 1; \
- printf "\n%*s ", -indent, " "; \
- } \
- printf "%s ", words[i]; \
- } \
- printf "\n"; \
- }' \
- | more $(shell test $(shell uname) = Darwin && echo '--no-init --raw-control-chars')
command.py ADDED
@@ -0,0 +1,40 @@
+ import argparse
+ import sys
+ import os
+
+ # parser = argparse.ArgumentParser()
+ # parser.add_argument("category")
+ # parser.add_argument("subcommand-args")
+ # args = parser.parse_args()
+ args = sys.argv
+
+ # remove "command.py"
+ args = args[1:]
+
+ # print(args)
+ subcommand = args[0].lower()
+
+ subcommand_args = " ".join(args[1:])
+ if subcommand=="data":
+     command = "py src/data/make_dataset.py "+subcommand_args
+     # print(command)
+     os.system(command)
+ else:
+     print("subcommand not supported.")
+
+ # os.system("py src/__init__.py")
+ """
+ download the dataset: data download
+ preprocess dataset: data prepare
+ visualize dataset: data show
+ delete raw & interim dataset dir: data delete --cache
+ delete all dataset dir: data delete --all
+
+
+ train model: model train
+ evaluate model: model evaluate
+ inference with model: model predict --image test.jpg --folder images/ -d results/
+
+
+
+ """
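A minimal usage sketch of the dispatcher above, assuming it is run from the repository root and that the Windows `py` launcher used in the hard-coded command string is available:

# $ python command.py data prepare --dataset forests
# command.py drops its own name, takes "data" as the subcommand and forwards the rest,
# so the invocation above effectively runs:
import os
os.system("py src/data/make_dataset.py prepare --dataset forests")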
config.yaml ADDED
@@ -0,0 +1,13 @@
+ raw_dataset_dir: data/raw/
+ interim_dataset_dir: data/interim/
+ processed_dataset_dir: data/processed/
+
+ # forests or pascal-voc
+ dataset: forests
+
+ image_size: 224
+ train_size: 0.8
+ shuffle: False
+ batch_size: 16
+
+ seed: 324
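These keys are read through the Config wrapper in src/utils.py and exposed as the module-level config object in src/__init__.py, so each key becomes an attribute. A minimal access sketch, assuming the repository root as the working directory:

from src import config

print(config.dataset)      # "forests"
print(config.image_size)   # 224
print(config.missing_key)  # None -- Config.__getattr__ falls back to dict.get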
requirements.txt CHANGED
@@ -1,14 +1,3 @@
- # local package
- -e .
-
- # external requirements
- click
- Sphinx
- coverage
- awscli
- flake8
- python-dotenv>=0.5.1
-
-
- # backwards compatibility
- pathlib2
+ huggingface_hub
+ comet_ml
+ scikit-image
src/__init__.py CHANGED
@@ -0,0 +1,9 @@
+ from src.utils import Config
+ from pathlib import Path
+
+ config = Config("config.yaml")
+ # config.raw_dataset_dir = Path(config.raw_dataset_dir)
+ # config.interim_dataset_dir = Path(config.interim_dataset_dir)
+ # config.processed_dataset_dir = Path(config.processed_dataset_dir)
+
+ # print(config)
src/data/load_dataset.py ADDED
@@ -0,0 +1,73 @@
+ import os,sys;sys.path.append(os.getcwd())
+ import tensorflow as tf
+ from src import config
+ from src.utils import *
+ from pathlib import Path
+ from glob import glob
+ import sklearn.model_selection
+ from skimage.color import rgb2lab, lab2rgb
+
+ def get_datasets():
+     trainval_dir = Path(config.processed_dataset_dir) / Path("trainval/")
+     test_dir = Path(config.processed_dataset_dir) / Path("test/")
+
+     trainval_paths = glob(str(trainval_dir/Path("*")))
+     test_paths = glob(str(test_dir/Path("*")))
+
+     print("trainval|test:",len(trainval_paths),"|",len(test_paths))
+
+
+
+     train_paths,val_paths = sklearn.model_selection.train_test_split(trainval_paths,
+                                                                      train_size=0.8,
+                                                                      random_state=324)
+
+     print("train|val split:",len(train_paths),"|",len(val_paths))
+
+     train_ds = get_ds(train_paths,bs=config.batch_size,shuffle=config.shuffle)
+     val_ds = get_ds(val_paths,bs=config.batch_size,shuffle=False,is_val=True)
+     test_ds = get_ds(test_paths,bs=config.batch_size,shuffle=False,is_val=True)
+
+     return train_ds,val_ds,test_ds
+
+
+ # def test_dataset():
+ #     train_ds = get_ds(train_paths,shuffle=False)
+ #     L_batch,AB_batch = next(iter(train_ds))
+ #     L_batch = L_batch.numpy()
+ #     AB_batch = AB_batch.numpy()
+ #     print("L:",L_batch.min(),L_batch.max())
+ #     print("A:",AB_batch[:,:,:,0].min(),AB_batch[:,:,:,0].max())
+ #     print("B:",AB_batch[:,:,:,1].min(),AB_batch[:,:,:,1].max())
+
+
+
+ def tf_RGB_TO_LAB(image):
+     def f(image):
+         image = rgb2lab(image).astype("float32")  # rgb2lab returns float64; cast to match the tf.float32 output declared below
+         return image
+     lab = tf.numpy_function(f,[image],tf.float32)
+     lab.set_shape(image.shape)
+     return lab
+
+
+ # load the image in lab space and split the l and ab channels
+ def load_img(img_path):
+     img_bytes = tf.io.read_file(img_path)
+     image = tf.image.decode_image(img_bytes,3,expand_animations=False)
+     image = tf.image.resize(image,[config.image_size,config.image_size])
+     image = image / 255.0
+     image = tf_RGB_TO_LAB(image)
+
+     L,AB = image[:,:,0:1],image[:,:,1:]
+     L,AB = scale_L(L),scale_AB(AB)
+     return L,AB
+
+ def get_ds(image_paths,bs=8,shuffle=False,is_val=False):
+     ds = tf.data.Dataset.from_tensor_slices(image_paths)
+     if shuffle: ds = ds.shuffle(len(image_paths))
+     ds = ds.map(load_img,num_parallel_calls=tf.data.AUTOTUNE)
+     ds = ds.batch(bs,num_parallel_calls=tf.data.AUTOTUNE,drop_remainder=not is_val)
+
+     return ds
+
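A minimal sketch of consuming these pipelines, assuming data/processed/ has already been populated by make_dataset.py:

from src.data.load_dataset import get_datasets

train_ds, val_ds, test_ds = get_datasets()
L_batch, AB_batch = next(iter(train_ds))   # scaled L and AB channels
print(L_batch.shape, AB_batch.shape)       # e.g. (16, 224, 224, 1) (16, 224, 224, 2)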
src/data/make_dataset.py CHANGED
@@ -1,30 +1,128 @@
- # -*- coding: utf-8 -*-
- import click
- import logging
- from pathlib import Path
- from dotenv import find_dotenv, load_dotenv
-
-
- @click.command()
- @click.argument('input_filepath', type=click.Path(exists=True))
- @click.argument('output_filepath', type=click.Path())
- def main(input_filepath, output_filepath):
-     """ Runs data processing scripts to turn raw data from (../raw) into
-         cleaned data ready to be analyzed (saved in ../processed).
-     """
-     logger = logging.getLogger(__name__)
-     logger.info('making final data set from raw data')
-
-
- if __name__ == '__main__':
-     log_fmt = '%(asctime)s - %(name)s - %(levelname)s - %(message)s'
-     logging.basicConfig(level=logging.INFO, format=log_fmt)
-
-     # not used in this stub but often useful for finding various files
-     project_dir = Path(__file__).resolve().parents[2]
-
-     # find .env automagically by walking up directories until it's found, then
-     # load up the .env entries as environment variables
-     load_dotenv(find_dotenv())
-
-     main()
+ from huggingface_hub import snapshot_download
+ import os,sys;sys.path.append(os.getcwd())
+ from src import config
+ from src.utils import *
+ import argparse
+ from pathlib import Path
+ from zipfile import ZipFile
+ from glob import glob
+ import cv2
+ import numpy as np
+ import matplotlib.pyplot as plt
+ from tqdm import tqdm
+ import shutil
+ from src.data.visualize_dataset import visualize_dataset
+
+ def download_dataset():
+     """Used to download dataset from hugging face
+     """
+     print_title(f"Downloading {config.dataset} dataset from hugging face")
+     snapshot_download(repo_id="Anuj-Panthri/Image-Colorization-Datasets",
+                       repo_type="dataset",
+                       local_dir=config.raw_dataset_dir,
+                       allow_patterns=f"{config.dataset}/*")
+
+
+ def unzip_dataset():
+     print_title(f"Unzipping dataset")
+     print("Extracting to :",Path(config.interim_dataset_dir)/Path("trainval/"))
+     with ZipFile(Path(config.raw_dataset_dir)/Path(f"{config.dataset}/trainval.zip"),"r") as zip:
+         zip.extractall(Path(config.interim_dataset_dir)/Path("trainval/"))
+
+     print("Extracting to :",Path(config.interim_dataset_dir)/Path("test/"))
+     with ZipFile(Path(config.raw_dataset_dir)/Path(f"{config.dataset}/test.zip"),"r") as zip:
+         zip.extractall(Path(config.interim_dataset_dir)/Path("test/"))
+
+
+ def clean_dataset():
+     print_title("CLEANING DATASET")
+     trainval_dir = Path(config.interim_dataset_dir) / Path("trainval/")
+     test_dir = Path(config.interim_dataset_dir) / Path("test/")
+
+     trainval_paths = glob(str(trainval_dir/Path("*")))
+     test_paths = glob(str(test_dir/Path("*")))
+
+     print("train,test: ",len(trainval_paths),",",len(test_paths),sep="")
+
+
+     def clean(image_paths,destination_dir):
+         if os.path.exists(destination_dir): shutil.rmtree(destination_dir)
+         os.makedirs(destination_dir)
+         for i in tqdm(range(len(image_paths))):
+             img = cv2.imread(image_paths[i])
+             img = cv2.resize(img,[128,128])
+             if not is_bw(img):
+                 shutil.copy(image_paths[i],  # was trainval_paths[i], which breaks when cleaning the test split
+                             destination_dir)
+         print("saved to:",destination_dir)
+
+     destination_dir = Path(config.processed_dataset_dir)/Path("trainval/")
+     clean(trainval_paths,destination_dir)
+
+     destination_dir = Path(config.processed_dataset_dir)/Path("test/")
+     clean(test_paths,destination_dir)
+
+     trainval_dir = Path(config.processed_dataset_dir) / Path("trainval/")
+     test_dir = Path(config.processed_dataset_dir) / Path("test/")
+
+     trainval_paths = glob(str(trainval_dir/Path("*")))
+     test_paths = glob(str(test_dir/Path("*")))
+
+     print("after cleaning train,test: ",len(trainval_paths),",",len(test_paths),sep="")
+
+
+ def prepare_dataset():
+     print_title(f"Preparing dataset")
+     download_dataset()
+     unzip_dataset()
+     clean_dataset()
+
+ def delete_cache():
+     ## clean old interim and raw datasets
+     print_title("deleting unused raw and interim dataset dirs")
+     if os.path.exists(config.raw_dataset_dir):
+         shutil.rmtree(config.raw_dataset_dir)
+     if os.path.exists(config.interim_dataset_dir):
+         shutil.rmtree(config.interim_dataset_dir)
+
+ def delete_all():
+     ## clean all datasets
+     print_title("deleting all dataset dirs")
+     if os.path.exists(config.raw_dataset_dir):
+         shutil.rmtree(config.raw_dataset_dir)
+     if os.path.exists(config.interim_dataset_dir):
+         shutil.rmtree(config.interim_dataset_dir)
+     if os.path.exists(config.processed_dataset_dir):
+         shutil.rmtree(config.processed_dataset_dir)
+
+
+ if __name__=="__main__":
+     parser = argparse.ArgumentParser()
+     parser.add_argument("command")
+     parser.add_argument("-d","--dataset",default="forests")
+     parser.add_argument("--cache",action="store_true",default=True)  # default True so a plain `delete` clears the cache dirs
+     parser.add_argument("--all",action="store_true")
+
+     """
+     prepare dataset: data prepare
+     visualize dataset: data show
+     delete raw & interim dataset dir: data delete --cache
+     delete all dataset dir: data delete --all
+     """
+
+     args = parser.parse_args()
+     # print(args)
+
+     if args.command=="prepare":
+         prepare_dataset()
+
+     elif args.command=="show":
+         visualize_dataset()
+
+     elif args.command=="delete":
+         if(args.all): delete_all()
+         elif(args.cache): delete_cache()
+
+     else:
+         print("unsupported")
+
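The same steps can also be driven programmatically; a minimal sketch, assuming the repository root as the working directory and a valid config.yaml:

import os, sys; sys.path.append(os.getcwd())
from src.data.make_dataset import prepare_dataset, delete_cache

prepare_dataset()   # download -> unzip -> filter black-and-white images into data/processed/
delete_cache()      # remove data/raw/ and data/interim/ once the processed set exists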
src/data/visualize_dataset.py ADDED
@@ -0,0 +1,52 @@
+ import os,sys;sys.path.append(os.getcwd())
+ from src.data.load_dataset import get_ds,get_datasets
+ from src import config
+ from src.utils import *
+ import matplotlib.pyplot as plt
+ import cv2
+ import math
+
+ def see_batch(L_batch,AB_batch,show_L=False,cols=4,row_size=5,col_size=5,title=None):
+     n = L_batch.shape[0]
+     rows = math.ceil(n/cols)
+     fig = plt.figure(figsize=(col_size*cols,row_size*rows))
+     if title:
+         plt.title(title)
+         plt.axis("off")
+
+     for i in range(n):
+         fig.add_subplot(rows,cols,i+1)
+         L,AB = L_batch[i],AB_batch[i]
+         L,AB = rescale_L(L), rescale_AB(AB)
+         # print(L.shape,AB.shape)
+         img = np.concatenate([L,AB],axis=-1)
+         img = cv2.cvtColor(img,cv2.COLOR_LAB2RGB)*255
+         # print(img.min(),img.max())
+         if show_L:
+             L = np.tile(L,(1,1,3))/100*255
+             img = np.concatenate([L,img],axis=1)
+         plt.imshow(img.astype("uint8"))
+     plt.show()
+
+
+ def visualize_dataset():
+     train_ds,val_ds,test_ds = get_datasets()
+     L_batch,AB_batch = next(iter(train_ds))
+     L_batch,AB_batch = L_batch.numpy(), AB_batch.numpy()
+     see_batch(L_batch,
+               AB_batch,
+               title="training dataset")
+
+     L_batch,AB_batch = next(iter(val_ds))
+     L_batch,AB_batch = L_batch.numpy(), AB_batch.numpy()
+     see_batch(L_batch,
+               AB_batch,
+               title="validation dataset")
+
+     L_batch,AB_batch = next(iter(test_ds))
+     L_batch,AB_batch = L_batch.numpy(), AB_batch.numpy()
+     see_batch(L_batch,
+               AB_batch,
+               title="testing dataset")
+
+
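see_batch can be exercised without a prepared dataset; a minimal sketch with a random batch in the scaled ranges the loader produces (L in [0, 1], AB in [-1, 1]):

import numpy as np
from src.data.visualize_dataset import see_batch

L_batch = np.random.rand(8, 224, 224, 1).astype("float32")            # scaled L channel
AB_batch = np.random.rand(8, 224, 224, 2).astype("float32") * 2 - 1   # scaled AB channels
see_batch(L_batch, AB_batch, show_L=True, title="random LAB batch")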
src/utils.py ADDED
@@ -0,0 +1,39 @@
+ import yaml
+ import numpy as np
+
+ class Config:
+     def __init__(self,path="config.yaml"):
+         with open(path,'r') as f:
+             self.config = yaml.safe_load(f)
+
+     def __str__(self):
+         return str(self.config)
+
+     def __getattr__(self, name: str):
+         return self.config.get(name)
+
+     # def __setattr__(self, name: str, value: any):
+     #     self.config[name]=value
+
+ def is_bw(img):
+     rg,gb,rb = img[:,:,0].astype(int)-img[:,:,1] , img[:,:,1].astype(int)-img[:,:,2] , img[:,:,0].astype(int)-img[:,:,2]  # cast avoids uint8 wrap-around in the differences
+     rg,gb,rb = np.abs(rg).sum(),np.abs(gb).sum(),np.abs(rb).sum()
+     avg = np.mean([rg,gb,rb])
+     # print(rg,gb,rb)
+
+     return avg<10
+
+ def print_title(msg:str,n=30):
+     print("="*n,msg.upper(),"="*n,sep="")
+
+ def scale_L(L):
+     return L/100
+ def rescale_L(L):
+     return L*100
+
+ def scale_AB(AB):
+     return AB/128
+
+ def rescale_AB(AB):
+     return AB*128
+
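A few quick checks of the helpers above, following directly from their definitions:

import numpy as np
from src.utils import is_bw, scale_L, rescale_L, scale_AB

gray = np.full((128, 128, 3), 90, dtype=np.uint8)   # identical channels -> treated as black-and-white
print(is_bw(gray))                                  # True

print(rescale_L(scale_L(np.array([50.0]))))         # [50.] -- L scaling round-trips
print(scale_AB(np.array([64.0])))                   # [0.5]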