Spaces:
Runtime error
Runtime error
Niv Sardi
commited on
Commit
·
b1d65c2
1
Parent(s):
1732876
split dataset into train test val
Browse files- python/common/defaults.py +1 -0
- python/split.py +43 -0
- run.sh +2 -0
python/common/defaults.py
CHANGED
@@ -26,3 +26,4 @@ AUGMENTED_LABELS_PATH = D('AUGMENTED_LABELS_PATH', f'{AUGMENTED_DATA_PATH}/label
|
|
26 |
AUGMENTED_IMAGES_PATH = D('AUGMENTED_IMAGES_PATH', f'{AUGMENTED_DATA_PATH}/images')
|
27 |
|
28 |
MAIN_CSV_PATH = D('MAIN_CSV_PATH', f'{DATA_PATH}/entities.csv')
|
|
|
|
26 |
AUGMENTED_IMAGES_PATH = D('AUGMENTED_IMAGES_PATH', f'{AUGMENTED_DATA_PATH}/images')
|
27 |
|
28 |
MAIN_CSV_PATH = D('MAIN_CSV_PATH', f'{DATA_PATH}/entities.csv')
|
29 |
+
SPLIT_DATA_PATH = D('SPLIT_DATA_PATH', f'{DATA_PATH}/split')
|
python/split.py
ADDED
@@ -0,0 +1,43 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/usr/bin/python
|
2 |
+
import os
|
3 |
+
import math
|
4 |
+
from common import defaults, mkdir
|
5 |
+
|
6 |
+
if __name__ == '__main__':
|
7 |
+
import argparse
|
8 |
+
parser = argparse.ArgumentParser(description='splits a yolo dataset between different data partitions')
|
9 |
+
parser.add_argument('datapath', metavar='datapath', type=str,
|
10 |
+
help='csv file', default=defaults.SQUARES_DATA_PATH)
|
11 |
+
parser.add_argument('--partitions', metavar='partitions', type=str, nargs='+',
|
12 |
+
help='data path', default=['train:0.8', 'val:0.1', 'test:0.1'])
|
13 |
+
parser.add_argument('--dest', metavar='dest', type=str,
|
14 |
+
help='dest path', default=defaults.SPLIT_DATA_PATH)
|
15 |
+
|
16 |
+
args = parser.parse_args()
|
17 |
+
|
18 |
+
def image_to_label(i):
|
19 |
+
l = i.replace('images', 'labels').replace('.png', '.txt').replace('.jpg', '.txt')
|
20 |
+
if os.path.exists(l):
|
21 |
+
return l
|
22 |
+
return None
|
23 |
+
|
24 |
+
images = [d for d in os.scandir(os.path.join(args.datapath, 'images'))]
|
25 |
+
|
26 |
+
np = -1
|
27 |
+
for d,r in [a.split(':') for a in args.partitions]:
|
28 |
+
p = np + 1
|
29 |
+
np = min(p + math.floor(len(images)*float(r)), len(images))
|
30 |
+
|
31 |
+
cpi = os.path.join(args.dest, d, 'images')
|
32 |
+
cpl = os.path.join(args.dest, d, 'labels')
|
33 |
+
rpi = os.path.relpath(os.path.join(args.datapath, 'images'), cpi)
|
34 |
+
rpl = os.path.relpath(os.path.join(args.datapath, 'labels'), cpl)
|
35 |
+
|
36 |
+
mkdir.make_dirs([cpi, cpl])
|
37 |
+
print( f'{d:6s} [ {p:6d}, {np:6d} ] ({np-p:6d}:{(np-p)/len(images):0.2f} )')
|
38 |
+
for si in images[p:np]:
|
39 |
+
l = image_to_label(si.path)
|
40 |
+
os.symlink(os.path.join(rpi, si.name), os.path.join(cpi, si.name))
|
41 |
+
if l:
|
42 |
+
nl = os.path.basename(l)
|
43 |
+
os.symlink(os.path.join(rpl, nl), os.path.join(cpl, nl))
|
run.sh
CHANGED
@@ -15,5 +15,7 @@ echo "✨ augmenting data"
|
|
15 |
${PY} ./python/augment.py
|
16 |
echo "🖼 croping augmented data"
|
17 |
${PY} ./python/crop.py ./data/augmented/images
|
|
|
|
|
18 |
echo "🧠 train model"
|
19 |
sh train.sh
|
|
|
15 |
${PY} ./python/augment.py
|
16 |
echo "🖼 croping augmented data"
|
17 |
${PY} ./python/crop.py ./data/augmented/images
|
18 |
+
echo "✂ split dataset into train, val and test groups"
|
19 |
+
${PY} ./python/split.py ./data/squares/
|
20 |
echo "🧠 train model"
|
21 |
sh train.sh
|