chicelli's picture
Upload 21 files
9f68e7c verified
raw
history blame contribute delete
791 Bytes
import os
import numpy as np
from img2art_search.constants import BASE_PATH
def get_data_from_local() -> np.ndarray:
    """Collect paired file paths from the local ``splits`` directories.

    Every entry found in ``{BASE_PATH}/splits/left_or_top`` is paired with a
    counterpart path obtained by replacing ``"left_or_top"`` with
    ``"right_or_bottom"`` in the full path string. The counterpart files are
    not checked for existence.

    Returns:
        A two-row array: row 0 holds the ``left_or_top`` paths and row 1 the
        derived ``right_or_bottom`` paths, column-aligned (shape
        ``(2, n_files)``).
    """
    left_dir = f"{BASE_PATH}/splits/left_or_top"
    # os.listdir returns entries in arbitrary order; sorting makes the result
    # (and any positional split built on top of it) reproducible across runs
    # and machines.
    left_or_top_data = [f"{left_dir}/{fn}" for fn in sorted(os.listdir(left_dir))]
    x = np.array(left_or_top_data)
    # NOTE(review): str.replace swaps *every* occurrence, so a BASE_PATH or
    # file name that itself contains "left_or_top" is rewritten too — confirm
    # this matches how the right_or_bottom files are named.
    y = np.array([ex.replace("left_or_top", "right_or_bottom") for ex in x])
    return np.array([x, y])
def split_train_val_test(data: np.ndarray, test_size: float, val_size: float) -> tuple:
    """Split ``data`` column-wise into train, validation and test parts.

    The split is positional (no shuffling): the leading columns go to train,
    the next ``val_size`` fraction to validation, and the remaining columns
    to test.

    Args:
        data: Array indexed as ``data[:, i]``; columns are the samples.
        test_size: Fraction of columns reserved for the test part.
        val_size: Fraction of columns reserved for the validation part.

    Returns:
        A ``(train, validation, test)`` tuple of slices of ``data``.

    Raises:
        ValueError: If a fraction lies outside ``[0, 1]`` or the two
            fractions sum to more than 1.
    """
    if not (0 <= test_size <= 1 and 0 <= val_size <= 1):
        raise ValueError("test_size and val_size must be between 0 and 1")
    if test_size + val_size > 1:
        raise ValueError("test_size + val_size must not exceed 1")

    n_samples = data.shape[1]
    train_size = 1 - test_size - val_size
    train_end = int(n_samples * train_size)
    # The validation slice must be sized by val_size; sizing it by test_size
    # would hand the validation part the test fraction and vice versa.
    val_end = train_end + int(n_samples * val_size)

    train = data[:, :train_end]
    validation = data[:, train_end:val_end]
    # The test part takes every remaining column, so truncation leftovers from
    # the int() conversions above end up here and no sample is dropped.
    test = data[:, val_end:]
    return train, validation, test