File size: 791 Bytes
9f68e7c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
import os

import numpy as np

from img2art_search.constants import BASE_PATH


def get_data_from_local(base_path: "str | None" = None) -> np.ndarray:
    """Collect paired image paths from the local ``splits`` directory.

    Every entry of ``<base_path>/splits/left_or_top`` is paired with the entry
    of the same name under ``<base_path>/splits/right_or_bottom``. Only the
    ``left_or_top`` directory is listed; whether the partner file exists is
    not checked here.

    Args:
        base_path: Root directory that contains ``splits/``. Falls back to
            ``BASE_PATH`` when omitted.

    Returns:
        A ``(2, N)`` array of path strings. Row 0 holds the ``left_or_top``
        paths and row 1 the matching ``right_or_bottom`` paths, ordered by
        file name.
    """
    if base_path is None:
        base_path = BASE_PATH
    left_dir = f"{base_path}/splits/left_or_top"
    right_dir = f"{base_path}/splits/right_or_bottom"

    # os.listdir returns entries in arbitrary order; sort them so the pairs
    # (and any positional train/val/test split made from them) are
    # reproducible across runs and filesystems.
    file_names = sorted(os.listdir(left_dir))

    x = np.array([f"{left_dir}/{fn}" for fn in file_names])
    # Build the partner path from the file name instead of string-replacing
    # "left_or_top" in the full path, which would also rewrite that
    # substring if it appeared in base_path or in the file name itself.
    y = np.array([f"{right_dir}/{fn}" for fn in file_names])

    return np.array([x, y])


def split_train_val_test(data: np.ndarray, test_size: float, val_size: float) -> tuple:
    """Split ``data`` along axis 1 into consecutive train/validation/test parts.

    No shuffling is performed: the first columns go to train, the next ones
    to validation, and the remaining ones to test.

    Args:
        data: Array whose second axis indexes samples.
        test_size: Fraction of samples reserved for the test split.
        val_size: Fraction of samples reserved for the validation split.

    Returns:
        ``(train, validation, test)`` slices of ``data``. Train holds
        ``int(N * (1 - test_size - val_size))`` columns and validation holds
        ``int(N * val_size)``. Test receives the rest, so any rounding
        leftovers land in the test split.

    Raises:
        ValueError: If either fraction is negative or they sum to more than 1.
    """
    if test_size < 0 or val_size < 0 or test_size + val_size > 1:
        raise ValueError(
            f"invalid split fractions: test_size={test_size}, val_size={val_size}"
        )

    n_samples = data.shape[1]
    train_size = 1 - test_size - val_size
    train_end = int(n_samples * train_size)
    # The validation slice must be sized by val_size (the original sized it
    # by test_size, swapping the validation and test proportions).
    val_end = train_end + int(n_samples * val_size)

    train = data[:, :train_end]
    validation = data[:, train_end:val_end]
    test = data[:, val_end:]
    return train, validation, test